""" This file provides functions that produce the visualizations for the
paper, including figures and tables.
"""

try:
    from src.interpolation.splines import UCGrid
except:
    from ..interpolation.splines import UCGrid
import scipy
import matplotlib.pyplot as plt
from matplotlib import rc
from matplotlib.lines import Line2D
import statsmodels.formula.api as smf
import pandas as pd
import numpy as np
import itertools
from src.models_new import model_objects
from src.models_new import  simulation, surplus

from scipy.stats import ttest_ind
import copy
from math import floor, log10
from scipy.optimize import linear_sum_assignment


try:
    from src.interpolation.splines import eval_linear
except:
    from ..interpolation.splines import eval_linear
from src.models_new import states


def counterfactual_graph(series, comparison, settings, output_path):
    ''' This does the graphs for the counterfactuals.

    The left panel shows the total change in welfare; the right panel shows
    the decomposition into entry, quality and quantity effects as stacked
    filled areas, together with the gas price on a twin axis.

    Inputs:
        series (dict or DataFrame): the series used in the graph, sharing a
            common index. This should contain the following keys:
                - initial
                - entry (optional; when absent the entry effect is zero)
                - quality
                - final
                - gas_price
            ('boom' may also be present; it is not used here.)

        comparison (str): either 'initial' or 'final' - the series whose
            mean the differences are expressed as a percentage of.

        settings (dict): settings for the plots. This should contain:
            'ylim' (dict): settings for y axis limits in form a_b where
                a in ['left', 'right'] for which graph and b in ['0', '1']
                for which axis (0 = welfare, 1 = gas price).
            'yticks' (dict): settings for y axis ticks, with the same keys
                as 'ylim'.
            The caller's dict is not modified.

        output_path (str): file the figure is saved to.

    Outputs:
        None (saves the figure to output_path and closes it)
    '''

    # Without a separate entry step the entry effect is identically zero.
    entry = series['entry'] if 'entry' in series else series['initial']
    scale = series[comparison].mean()

    # Individual (unstacked) effects, in percent of the comparison mean.
    diff_init = {
        'total': 100 * (series['final'] - series['initial']) / scale,
        'entry': 100 * (entry - series['initial']) / scale,
        'quality': 100 * (series['quality'] - entry) / scale,
        'quantity': 100 * (series['final'] - series['quality']) / scale,
    }

    # Stacked values for the decomposition panel. Every area is filled from
    # 0 and drawn with entry on top (highest zorder), then quantity, then
    # quality, so an effect is stacked on the lower layers that push in the
    # same direction. Work on copies so diff_init keeps the raw effects.
    diff = {name: values.copy() for name, values in diff_init.items()}
    entry_eff = diff_init['entry']
    quantity_eff = diff_init['quantity']
    quality_eff = diff_init['quality']

    def _same_sign(a, b):
        # Element-wise: both strictly positive or both strictly negative.
        return ((a > 0) & (b > 0)) | ((a < 0) & (b < 0))

    # Quantity rests on entry when both have the same sign.
    diff['quantity'].loc[_same_sign(entry_eff, quantity_eff)] = \
        entry_eff + quantity_eff

    # Quality rests on the (already stacked) quantity area when they share a
    # sign; otherwise directly on entry when those two share a sign. Applying
    # the second rule only where the first did not avoids adding the entry
    # effect twice when all three effects have the same sign.
    on_quantity = _same_sign(quantity_eff, quality_eff)
    diff['quality'].loc[on_quantity] = diff['quantity'] + quality_eff
    on_entry_only = _same_sign(entry_eff, quality_eff) & ~on_quantity
    diff['quality'].loc[on_entry_only] = entry_eff + quality_eff

    # Make the settings for the graph
    plt.rcParams['text.latex.preamble'] = r"\usepackage{lmodern}"  # Use latex
    plt.rcParams['font.family'] = 'serif'
    plt.rcParams['font.serif'] = ['lmodern'] + plt.rcParams['font.serif']
    rc('text', usetex=True)
    plt.rc('font', size=9)          # controls default text sizes
    plt.rc('axes', titlesize=9)     # fontsize of the axes title
    plt.rc('axes', labelsize=9)     # fontsize of the x and y labels
    plt.rc('xtick', labelsize=9)    # fontsize of the tick labels
    plt.rc('ytick', labelsize=9)    # fontsize of the tick labels
    plt.rc('legend', fontsize=8)    # legend fontsize
    plt.rc('figure', titlesize=10)  # figure title size

    # Hard code some settings (on a copy, so the caller's dict is untouched)
    settings = dict(settings)
    settings['colors'] = {
        'total': 'dodgerblue',
        'quantity': 'darkorange',
        'quality': 'silver',
        'entry': 'red'
    }
    settings['labels'] = {
        'total': 'Total effect',
        'quantity': 'Quantity effect',
        'quality': 'Quality effect',
        'entry': 'Entry effect'
    }
    settings['zorder'] = {
        'total': 0,
        'quantity': 40,
        'quality': 20,
        'entry': 60
    }

    # Set up the figure: each panel has a welfare axis (_0) and a twin
    # gas-price axis (_1).
    f = plt.figure(figsize=(6.3, 6.3/2))

    ax = dict()
    ax['left_0'] = f.add_subplot(1, 2, 1)
    ax['left_1'] = ax['left_0'].twinx()
    ax['right_0'] = f.add_subplot(1, 2, 2)
    ax['right_1'] = ax['right_0'].twinx()

    # Legend handles per panel; the gas-price line is listed first in both.
    # Raw strings: '\$' is a literal backslash-dollar for LaTeX.
    legend_gas = Line2D([0], [0], color='black', label=r'Gas Price (\$)', lw=2)
    legend = {
        'left': {r'Gas Price (\$)': legend_gas},
        'right': {r'Gas Price (\$)': legend_gas}
    }

    # Get the names of series to be plotted
    names = {
        'left': ['total'],
        'right': (['entry', 'quality', 'quantity'] if 'entry' in series
                  else ['quantity', 'quality']),
    }

    # Do plots
    for i in ['left', 'right']:
        # 1. Plot the series
        # Gas price
        ax[i + '_1'].plot(series['gas_price'], color='black')

        # Welfare
        for name in names[i]:
            # Transparent (alpha=0) line; only its effect on the data limits matters.
            ax[i + '_0'].plot(
                diff[name],
                alpha=0
            )
            legend[i][settings['labels'][name]] = ax[i + '_0'].fill_between(
                diff[name].index,
                diff[name],
                0,
                facecolor=settings['colors'][name],
                alpha=1,
                label=settings['labels'][name],
                zorder=settings['zorder'][name]
            )

        # 2. Add in settings
        ax[i + '_0'].set_ylabel(r'Change in Welfare (\%)', labelpad=0)
        ax[i + '_1'].set_ylabel(r'Natural Gas Price (\$)', labelpad=0)
        ax[i + '_0'].set_xticks([0, 60, 120])
        ax[i + '_0'].set_xticklabels([2000, 2005, 2010])
        ax[i + '_0'].axhline(y=0, c="black", linewidth=0.5)
        ax[i + '_0'].set_xlabel('Year', labelpad=0)
        ax[i + '_0'].legend(
            list(legend[i].values()),
            list(legend[i].keys()),
            frameon=False,
            loc='lower left'
        )

    # Add in settings for the y axes
    for i in settings['ylim']:
        ax[i].set_ylim(settings['ylim'][i])
        ax[i].set_yticks(settings['yticks'][i])

    # Titles
    ax['left_0'].set_title('(a) Total Change', pad=10)
    ax['right_0'].set_title('(b) Decomposition', pad=10)

    # Save, then release the figure so repeated calls do not accumulate
    # open figures.
    f.tight_layout(pad=1)
    plt.savefig(output_path, pad_inches=0)
    plt.close(f)


def counterfactual_table(series, comparison, overleaf_path, fig_name):
    """ Build the table that summarizes the counterfactuals.

    Inputs:
        series (df): the series used in the graph. This should contain
            the following columns:
                - boom (boolean series of where the market is in a boom)
                - initial (the initial total series)
                - entry (baseline plus entry effect)
                - quality (baseline plus entry and quality effect)
                - final (the final series with all changes)
            Note that the series which are inputted should be in dollar
            values and not percentage change. Also the indexes of these
            series should be the date.

        comparison (str): either 'initial' or 'final' - the column whose
            total the differences are expressed as a percentage of.

        overleaf_path (str): output prefix including the trailing
            separator; its 'parameters/' and 'tables/' subdirectories must
            already exist.

        fig_name (str): name used for the output files. For 'benchmark'
            and 'intermediary' the entry effect is reported as 0.

    Outputs:
        vals (dict): percentage effects rounded to one decimal, keyed
            '{row}_{col}' with row in ['quality', 'quantity', 'entry',
            'total'] and col in ['boom', 'bust', 'average'].

        Side effects: each value (followed by a LaTeX percent sign) is
        written to overleaf_path + 'parameters/{fig_name}_{row}_{col}.tex';
        unless fig_name is 'benchmark_fortnight', the template
        'src/tex/table_counterfactual.tex' is filled with vals and written
        to overleaf_path + 'tables/table_{fig_name}.tex'.
    """

    # Sum within bust (boom == False) and boom (boom == True) periods and
    # append the sum over all periods. ignore_index renumbers the rows as
    # 0 = bust, 1 = boom, 2 = all periods (this assumes both regimes occur).
    df_boom = series.groupby('boom').sum()
    df_boom = pd.concat([df_boom, pd.DataFrame([series.sum()], index=[2])], ignore_index=True)

    # Every effect is expressed relative to the total of the comparison column.
    denominator = series[comparison].sum()
    df_boom['percent_diff_total'] = \
        100 * (df_boom['final'] - df_boom['initial']) / denominator
    df_boom['percent_diff_entry'] = \
        100 * (df_boom['entry'] - df_boom['initial']) / denominator
    if fig_name in ('benchmark', 'intermediary'):
        df_boom['percent_diff_entry'] = 0.0
    df_boom['percent_diff_quality'] = \
        100 * (df_boom['quality'] - df_boom['entry']) / denominator
    df_boom['percent_diff_quantity'] = \
        100 * (df_boom['final'] - df_boom['quality']) / denominator

    table = df_boom[[
        'percent_diff_entry',
        'percent_diff_quality',
        'percent_diff_quantity',
        'percent_diff_total'
    ]]

    table = table.T.rename(
        {0: 'bust',
         1: 'boom',
         2: 'average'},
        axis=1
    )
    table = table[['boom', 'bust', 'average']].round(1)

    # Write to the tex format
    with open('src/tex/table_counterfactual.tex', 'r') as f:
        tex = f.read()

    rows = ['quality', 'quantity', 'entry', 'total']
    cols = ['boom', 'bust', 'average']
    vals = dict()
    for row, col in itertools.product(rows, cols):
        item = table.loc['percent_diff_' + row, col]
        # -0.0 == 0.0, so anything that rounds to zero (including negative
        # zero) becomes a plain 0.0 and never prints as "-0.0".
        if round(item, 1) == -0.0:
            item = 0.0
        vals[row + '_' + col] = item
        param_name = overleaf_path + f"parameters/{fig_name}_{row}_{col}.tex"
        with open(param_name, "wb") as f:
            # Raw string: a literal backslash-percent for LaTeX; the plain
            # literal is an invalid Python escape sequence.
            f.write((str(item) + r'\%').encode())

    if fig_name != 'benchmark_fortnight':
        output = tex.format(**vals)
        with open(overleaf_path + f'tables/table_{fig_name}.tex', 'w') as f:
            f.write(output)

    return vals


def counterfactual_absolute(series, comparison, overleaf_path, fig_name):
    """Write the absolute (level) counterfactual numbers as LaTeX parameters.

    Every column of ``series`` is summed; from those totals two files are
    written under ``overleaf_path + 'parameters/'``:

        {fig_name}_absolute_diff_total.tex
            30 * (sum(final) - sum(initial)), rounded to an integer.
        {fig_name}_absolute.tex
            30 * sum(final) / 1000, rounded to one decimal.

    Inputs:
        series (DataFrame): must contain 'initial' and 'final' columns.
        comparison (str): not used; the signature mirrors counterfactual_table.
        overleaf_path (str): output prefix including the trailing separator.
        fig_name (str): prefix of the parameter file names.
    """
    totals = series.sum()

    # Factor 30 scales the summed values (elsewhere in this module it turns
    # day rates into monthly amounts -- confirm the units for this series).
    diff_total = int(round(30 * (totals['final'] - totals['initial']), 0))
    absolute_level = round(30 * totals['final'] / 1000, 1)

    rendered = {
        'absolute_diff_total': str(diff_total),
        'absolute': str(absolute_level),
    }
    for suffix, text in rendered.items():
        target = overleaf_path + f"parameters/{fig_name}_{suffix}.tex"
        with open(target, "wb") as handle:
            handle.write(text.encode())


def robustness_two_weeks_table(vals_benchmark, vals_benchmark_fortnight, overleaf_path):
    """Fill the two-week robustness table and write it to the Overleaf folder.

    The template 'src/tex/table_two_week_counterfactual.tex' is formatted
    with two keyword arguments, ``baseline_sorting`` and
    ``fortnight_sorting``, bound to the two dicts passed in (so the template
    indexes into them, e.g. '{baseline_sorting[total_average]}').

    Inputs:
        vals_benchmark (dict): values for the baseline specification.
        vals_benchmark_fortnight (dict): values for the two-week specification.
        overleaf_path (str): output prefix including the trailing separator;
            the result goes to overleaf_path + 'tables/table_two_week_counterfactual.tex'.
    """
    template_path = 'src/tex/table_two_week_counterfactual.tex'
    with open(template_path, 'r') as template_file:
        template = template_file.read()

    filled = template.format(
        baseline_sorting=vals_benchmark,
        fortnight_sorting=vals_benchmark_fortnight,
    )

    target_path = overleaf_path + 'tables/table_two_week_counterfactual.tex'
    with open(target_path, 'w') as target_file:
        target_file.write(filled)


def robustness_rig_target_table(vals_benchmark, vals_benchmark_rig_target, overleaf_path):
    """Fill the rig-target robustness table and write it to the Overleaf folder.

    Only the 'total_average' entry of each input dict is used; the template
    'src/tex/table_rig_target_counterfactual.tex' receives them as
    ``baseline_sorting`` and ``rig_target_sorting``.

    Inputs:
        vals_benchmark (dict): results of the baseline specification.
        vals_benchmark_rig_target (dict): results of the rig-target specification.
        overleaf_path (str): output prefix including the trailing separator;
            the result goes to overleaf_path + 'tables/table_rig_target_counterfactual.tex'.
    """
    baseline_total = vals_benchmark['total_average']
    rig_target_total = vals_benchmark_rig_target['total_average']

    template_path = 'src/tex/table_rig_target_counterfactual.tex'
    with open(template_path, 'r') as template_file:
        template = template_file.read()

    filled = template.format(
        baseline_sorting=baseline_total,
        rig_target_sorting=rig_target_total,
    )

    target_path = overleaf_path + 'tables/table_rig_target_counterfactual.tex'
    with open(target_path, 'w') as target_file:
        target_file.write(filled)


def get_value_search_df(estimator):
    """Reshape the estimator's value-of-search paths into one long Series.

    Inputs:
        estimator: object with a ``v_evol_dict`` attribute; for each spec in
            ('low', 'mid', 'high'), ``v_evol_dict[spec]`` maps a date-like key
            to a sequence of values indexed by integer horizon (presumably
            months ahead, 0, 1, 2, ... -- confirm against the estimator).

    Returns:
        Series indexed by ('date', 'expectation', 'spec'), where
        expectation = date + horizon months. Only horizons <= 12 are kept,
        and values are multiplied by 30.
    """
    by_spec = {}
    for spec in ('low', 'mid', 'high'):
        # Rows: dates; columns: horizons.
        wide = pd.DataFrame(estimator.v_evol_dict[spec]).T
        wide.index = pd.to_datetime(wide.index)

        # Long form: columns 'level_0' (date), 'level_1' (horizon), 0 (value).
        long_form = wide.stack().sort_index().reset_index()
        long_form = long_form.loc[long_form['level_1'] <= 12].copy()

        starts = long_form['level_0'].tolist()
        horizons = long_form['level_1'].tolist()
        long_form['12_month'] = [
            start + pd.DateOffset(months=horizon)
            for start, horizon in zip(starts, horizons)
        ]

        by_spec[spec] = long_form.set_index(['level_0', '12_month'])[0]

    # * 30 converts from dayrate per month to total (in millions of USD)
    stacked = pd.DataFrame(by_spec).stack() * 30
    stacked.index.names = ['date', 'expectation', 'spec']

    return stacked


def get_value_function_data(df_state, const, r,
                            search_grid_by_spec, search_value_by_spec):
    """Roll each observed state forward 12 monthly steps and evaluate the
    value of search along every path.

    Inputs:
        df_state (DataFrame): rows with fields date, g, n_l, n_m and n_h.
        const, r: transition parameters, passed to states.next_state as
            ``const.values.T[0]`` and ``r.values`` (pandas objects,
            presumably -- confirm against the caller).
        search_grid_by_spec, search_value_by_spec (dict): per spec
            ('low', 'mid', 'high'), the grid and values handed to eval_linear.

    Returns:
        DataFrame indexed by ('date', 'expectation') with the state columns
        g, n_l, n_m, n_h and v_low, v_mid, v_high (interpolated value * 30).
    """
    one_month = pd.DateOffset(months=1)
    state_cols = ['g', 'n_l', 'n_m', 'n_h']

    paths = dict()
    for obs in df_state.itertuples():
        current = np.array([obs.g, obs.n_l, obs.n_m, obs.n_h])
        horizon_date = obs.date
        # NOTE(review): the first stored entry (expectation == date) is
        # already one transition ahead of the observed state -- confirm this
        # timing is intended.
        for _ in range(12):
            stepped = copy.deepcopy(states.next_state(
                current,
                const=const.values.T[0],
                r=r.values,
                sigma=0)
            )
            paths[(obs.date, horizon_date)] = stepped
            # Deep copies keep the stored states independent of later steps.
            current = copy.deepcopy(stepped)
            horizon_date = horizon_date + one_month

    evolution = pd.DataFrame(paths).T
    evolution.columns = state_cols
    evolution.index.names = ['date', 'expectation']

    for spec in ('low', 'mid', 'high'):
        evolution[f'v_{spec}'] = eval_linear(
            search_grid_by_spec[spec],
            search_value_by_spec[spec],
            np.array(evolution[state_cols], order='C')
        ) * 30

    return evolution

def value_search_graph(df, output_path):
    """Plot the value of search for each rig spec together with the 12-month
    expectation paths.

    Inputs:
        df (DataFrame): indexed by ('date', 'expectation') with columns
            'v_low', 'v_mid' and 'v_high' (e.g. the output of
            get_value_function_data). Rows with date == expectation form the
            headline series; each date's rows form one light-gray
            expectation path.
        output_path (str): file the figure is saved to.

    Outputs:
        None (saves the figure to output_path and closes it)
    """
    settings = {
        'color': {
            'low': 'dodgerblue',
            'mid': 'blue',
            'high': 'black'
        },
        'linestyle': {
            'low': '-.',
            'mid': '--',
            'high': '-',
        },
        'label': {
            'low': 'Low-spec',
            'mid': 'Mid-spec',
            'high': 'High-spec'
        }
    }

    # Headline series: the rows where the expectation date is the date itself.
    df_value = df[
        (df.index.get_level_values('date')
         == df.index.get_level_values('expectation'))
    ]

    plt.rcParams['text.latex.preamble'] = r"\usepackage{lmodern}"  # Use latex
    plt.rcParams['font.family'] = 'serif'
    plt.rcParams['font.serif'] = ['lmodern'] + plt.rcParams['font.serif']
    rc('text', usetex=True)

    fig = plt.figure(figsize=(6.3, 6))

    ax = dict()
    ax['high'] = fig.add_subplot(3, 1, 1)
    ax['mid'] = fig.add_subplot(3, 1, 2)
    ax['low'] = fig.add_subplot(3, 1, 3)

    # Slice each date's expectation path once and reuse it for all specs.
    expectation_paths = [
        df.xs(date) for date in df.index.get_level_values('date').unique()
    ]

    for spec in ['low', 'mid', 'high']:
        for path in expectation_paths:
            ax[spec].plot(
                path[f'v_{spec}'],
                color='lightgray',
                linewidth=1,
                linestyle="-",
                marker=""
            )

        ax[spec].plot(
            df_value.index.get_level_values('date'),
            df_value[f"v_{spec}"],
            color=settings['color'][spec],
            linewidth=1.5,
            linestyle=settings['linestyle'][spec],
            marker="",
            label=settings['label'][spec]
        )
        ax[spec].set_xticks(
            ticks=['2000-01-01', '2005-01-01', '2009-12-31']
        )
        ax[spec].set_xticklabels(
            labels=['2000', '2005', '2010']
        )
        ax[spec].set_ylabel('Millions of USD')

        # Proxy artists: one for the spec's own line, one for the gray paths.
        handles = [
            Line2D(
                [0],
                [0],
                color=settings['color'][spec],
                linestyle=settings['linestyle'][spec],
                lw=1
            ),
            Line2D([0], [0], color='lightgray', lw=1)
        ]
        ax[spec].legend(
            handles,
            [settings['label'][spec], 'Expectations (12 months)'],
            frameon=False,
            loc='upper left'
        )

    # Bottom panel carries the shared x-axis label.
    ax['low'].set_xlabel('Year')

    # tight_layout only has an effect once the axes exist, so call it after
    # everything has been drawn.
    fig.tight_layout(pad=0)

    # Save, then release the figure.
    plt.savefig(output_path, pad_inches=0)
    plt.close(fig)


def smm_to_tex(params_other, params_match, delta,
               ci_lower_params_other, ci_upper_params_other,
               input_path, output_path, overleaf_path):
    """ Takes the SMM estimates (and bootstrap confidence bounds) and formats
    them into a tex file for the paper.

    Args:
        params_other (dict-like): estimated parameters keyed by name. Values
            are rounded to 1 decimal for d_0/d_1, 3 for rho_0..rho_2, 4 for
            rho_3 and 2 otherwise; 'c' is multiplied by 30, rounded to 2
            decimals and reported without bounds ('n.a.').
        params_match (dict-like): further parameters keyed by name, reported
            multiplied by 1000 and rounded to 1 decimal.
        delta (float): reported rounded to 2 decimals.
        ci_lower_params_other, ci_upper_params_other (dict-like or None):
            lower/upper bounds keyed by the names in params_other and
            params_match, scaled and rounded like the estimates. A bound that
            is None is rendered as None.
        input_path (str): tex template with '{name}', '{ql_name}',
            '{qu_name}' and '{delta}' placeholders.
        output_path (str): where the filled template is written.
        overleaf_path (str): prefix including the trailing separator; each
            params_other value is also written to
            overleaf_path + 'parameters/{name}.tex'.

    Returns:
        None
    """

    # Decimal places per parameter; names not listed (other than 'c') use 2.
    digits_by_name = {
        'd_0': 1, 'd_1': 1,
        'rho_0': 3, 'rho_1': 3, 'rho_2': 3,
        'rho_3': 4,
    }

    def _bound(ci, key, digits, scale=1):
        # Bounds are optional: without a bootstrap the entry is None.
        if ci is None:
            return None
        return round(ci[key] * scale, digits)

    # Format the values nicely. Names are compared with == rather than `is`:
    # equal strings are not guaranteed to be the same object (e.g. names
    # read from a file or a DataFrame index).
    tex_input = dict()
    for v in params_other:
        if v == 'c':
            tex_input[v] = round(params_other[v] * 30, 2)
            tex_input['ql_' + v] = 'n.a.'
            tex_input['qu_' + v] = 'n.a.'
            continue
        digits = digits_by_name.get(v, 2)
        tex_input[v] = round(params_other[v], digits)
        tex_input['ql_' + v] = _bound(ci_lower_params_other, v, digits)
        tex_input['qu_' + v] = _bound(ci_upper_params_other, v, digits)

    for i in params_match:
        tex_input[f'{i}'] = round(params_match[i] * 1000, 1)
        tex_input[f'ql_{i}'] = _bound(ci_lower_params_other, i, 1, scale=1000)
        tex_input[f'qu_{i}'] = _bound(ci_upper_params_other, i, 1, scale=1000)

    tex_input['delta'] = round(delta, 2)

    # Get the tex format
    with open(input_path, 'r') as f:
        tex = f.read()

    output = tex.format(**tex_input)

    with open(output_path, 'w') as f:
        f.write(output)

    # Write the parameters
    for i in params_other:
        item = tex_input[i]
        param_name = overleaf_path + f"parameters/{i}.tex"
        with open(param_name, "wb") as f:
            f.write((str(item)).encode())


def smm_moments_table(moments_sim, moments_data_non_coef, moments_data_coef, mean_reneg,
                      input_path, output_path):
    """Fill the moments-fit table template with simulated and empirical moments.

    Inputs:
        moments_sim (Series, presumably): simulated moments keyed by name;
            must include 'price_high', 'price_mid', 'price_low' and every key
            of ``rename_index`` below. It is not modified.
        moments_data_non_coef, moments_data_coef (DataFrame): empirical
            moments, with the values in a column named '0'.
        mean_reneg: empirical value stored under the 'reneg' key.
        input_path (str): tex template with '{simulated_...}' and
            '{empirical_...}' placeholders.
        output_path (str): where the filled template is written.

    Returns:
        moments_all (dict): the formatted moments, keys prefixed with
            'simulated_' or 'empirical_' (regression-term names are
            shortened via ``rename_index``).
    """

    # Work on copies: the steps below add entries and rescale values in
    # place, which must not leak into the caller's objects (otherwise a
    # second call would multiply the same moments by 10 again).
    moments_data = pd.concat([moments_data_coef, moments_data_non_coef])['0'].copy()
    moments_sim = moments_sim.copy()

    moments_data['reneg'] = mean_reneg
    moments_data['diff_high_mid'] = moments_data['price_high'] - moments_data[
        'price_mid']
    moments_data['diff_mid_low'] = moments_data['price_mid'] - moments_data['price_low']
    moments_sim['diff_high_mid'] = moments_sim['price_high'] - moments_sim['price_mid']
    moments_sim['diff_mid_low'] = moments_sim['price_mid'] - moments_sim['price_low']

    # Regression-term names -> short placeholder names used in the template.
    rename_index = {
        'Intercept': 'Intercept',
        'C(spec, Treatment(reference="low"))[T.mid]': 'beta_mid',
        'C(spec, Treatment(reference="low"))[T.high]': 'beta_high',
        'mri:C(spec, Treatment(reference="low"))[low]': 'beta_mri_low',
        'mri:C(spec, Treatment(reference="low"))[mid]': 'beta_mri_mid',
        'mri:C(spec, Treatment(reference="low"))[high]': 'beta_mri_high',
        'g:value': 'beta_2',
        'diff_high_mid': 'diff_high_mid',
        'diff_mid_low': 'diff_mid_low'
    }

    # smarter_round is defined outside this block; it is used here as a
    # factory returning a rounding callable.
    rounder = smarter_round(1)

    # These moments are reported multiplied by 10.
    for i in rename_index:
        moments_data[i] = rounder(moments_data[i] * 10)
        moments_sim[i] = rounder(moments_sim[i] * 10)

    # Add prefix to series
    moments_sim = (
        moments_sim
        .rename(rename_index)
        .add_prefix(prefix='simulated_')
        .apply(smarter_round(2))
        .to_dict()
    )
    moments_data = (
        moments_data
        .rename(rename_index)
        .add_prefix(prefix='empirical_')
        .apply(smarter_round(2))
        .to_dict()
    )

    moments_all = {**moments_sim, **moments_data}

    # Get the tex format
    with open(input_path, 'r') as f:
        tex = f.read()

    output = tex.format(**moments_all)

    with open(output_path, 'w') as f:
        f.write(output)

    return moments_all


def sorting_fit(moments_all, overleaf_path):
    """Figure comparing empirical and simulated sorting by rig type.

    Reads ``moments_all[f'{source}_mri_mean_{spec}_{phase}']`` for source in
    ('empirical', 'simulated'), spec in ('low', 'mid', 'high') and phase in
    ('boom', 'bust') -- e.g. the dict returned by smm_moments_table -- and
    draws three panels: (a) bust data vs simulation, (b) boom data vs
    simulation, (c) simulated bust vs simulated boom.

    The figure is saved to overleaf_path + 'figures/figure_sorting_fit.pdf'.
    """
    # Typography: LaTeX text with Latin Modern serif fonts.
    plt.rcParams['text.latex.preamble'] = r"\usepackage{lmodern}"
    plt.rcParams['font.family'] = 'serif'
    plt.rcParams['font.serif'] = ['lmodern'] + plt.rcParams['font.serif']
    rc('text', usetex=True)
    text_sizes = (
        ('font', {'size': 9}),
        ('axes', {'titlesize': 9}),
        ('axes', {'labelsize': 9}),
        ('xtick', {'labelsize': 9}),
        ('ytick', {'labelsize': 9}),
        ('legend', {'fontsize': 8}),
        ('figure', {'titlesize': 10}),
    )
    for group, options in text_sizes:
        plt.rc(group, **options)

    # Mean well complexity per rig type, one column per (source, phase);
    # rig types sit at x positions 1 (low), 2 (mid) and 3 (high).
    spec_positions = ((1, 'low'), (2, 'mid'), (3, 'high'))
    means = {
        (source, phase): {
            position: moments_all[f'{source}_mri_mean_{spec}_{phase}']
            for position, spec in spec_positions
        }
        for source in ('empirical', 'simulated')
        for phase in ('boom', 'bust')
    }
    df = pd.DataFrame(means).unstack()
    df.index.names = ['type', 'b', 'spec']

    # (panel key, subplot position, title, curves); each curve is
    # (source, phase, plot keyword arguments).
    panels = (
        ('left', 1, '(a) Sorting: Bust', (
            ('empirical', 'bust', {'marker': '.', 'color': 'blue', 'linewidth': 1.5,
                                   'linestyle': '--', 'label': "Data", 'alpha': 0.6}),
            ('simulated', 'bust', {'marker': 'x', 'color': 'blue', 'linewidth': 1.5,
                                   'label': "Simulation", 'alpha': 0.4}),
        )),
        ('center', 2, '(b) Sorting: Boom', (
            ('empirical', 'boom', {'marker': '.', 'color': 'dodgerblue', 'linewidth': 1.5,
                                   'linestyle': '--', 'label': "Data", 'alpha': 0.4}),
            ('simulated', 'boom', {'marker': 'x', 'color': 'dodgerblue', 'linewidth': 1.5,
                                   'label': "Simulation", 'alpha': 0.4}),
        )),
        ('right', 3, '(c) Sorting: Rotation', (
            ('simulated', 'bust', {'marker': '.', 'color': 'blue', 'linewidth': 1.5,
                                   'label': "Sim: Bust", 'alpha': 0.4}),
            ('simulated', 'boom', {'marker': '.', 'color': 'dodgerblue', 'linewidth': 1.5,
                                   'label': "Sim: Boom", 'alpha': 0.4}),
        )),
    )

    fig = plt.figure(figsize=(6.3, 3))

    for _, position, title, curves in panels:
        axis = fig.add_subplot(1, 3, position)
        for source, phase, style in curves:
            axis.plot(df.xs((source, phase), level=['type', 'b']), **style)

        axis.set_yticks([0.7, 0.8, 0.9, 1.0])
        axis.set_ylim([0.7, 1.0])
        axis.set_xticks([1, 2, 3])
        axis.set_xlim([0.9, 3.1])
        axis.set_xticklabels(['low', 'mid', 'high'])
        axis.legend(frameon=False, loc='lower right')
        axis.set_xlabel('Rig Type')
        axis.set_title(title)

    fig.tight_layout(pad=1)

    plt.savefig(
        overleaf_path + 'figures/figure_sorting_fit.pdf',
        pad_inches=0
    )


def graph_matching_patterns(data, output_path):
    """Plot how well complexity ('mri') relates to rig type ('max_wd').

    Panel (a) shows the mean of 'mri' for each distinct 'max_wd' value.
    Panel (b) shows, for each 'max_wd' value, the range between the 10th and
    90th percentiles of 'mri'.

    Inputs:
        data (DataFrame): must contain the columns 'mri' and 'max_wd'.
        output_path (str): file the figure is saved to.

    Outputs:
        None (saves the figure to output_path and closes it)
    """
    plt.rc('font', size=9)  # controls default text sizes
    plt.rc('axes', titlesize=11)  # fontsize of the axes title
    plt.rc('axes', labelsize=10)  # fontsize of the x and y labels
    plt.rc('xtick', labelsize=9)  # fontsize of the tick labels
    plt.rc('ytick', labelsize=9)  # fontsize of the tick labels
    plt.rc('legend', fontsize=9)  # legend fontsize
    plt.rc('figure', titlesize=11)  # figure title size

    y = 'max_wd'

    f = plt.figure(figsize=(7, 4.8))

    # Left panel: average complexity matched to each rig type
    data_av = data[['mri', y]].groupby(y).mean()
    data_av.reset_index(inplace=True)

    ax0 = plt.subplot(1, 2, 1)
    ax0.scatter(data_av[y], data_av['mri'], marker='o', color='blue', s=6)
    ax0.set_ylim([0, 2.0])
    ax0.set_xlim([0, 500])
    ax0.spines['top'].set_visible(False)
    ax0.spines['right'].set_visible(False)
    ax0.set_ylabel('Well Complexity Index Ranking')
    ax0.set_xticks([100, 200, 300, 400, 500])
    ax0.set_yticks([0, 0.5, 1.0, 1.5, 2.0])
    ax0.set_xlabel('Rig Type Ranking (Max. Drilling Depth)')

    # Right panel: 10th-90th percentile range of complexity per rig type
    mri_by_rig = data[['mri', y]].groupby(y)['mri']
    data_range = mri_by_rig.quantile(0.1).to_frame('mri_q10')
    data_range['mri_q90'] = mri_by_rig.quantile(0.9)
    data_range.reset_index(inplace=True)

    ax1 = plt.subplot(1, 2, 2)
    # One vectorised call per artist type instead of three calls per row.
    ax1.vlines(x=data_range[y], ymin=data_range['mri_q10'],
               ymax=data_range['mri_q90'], color='blue', linewidths=0.6)
    ax1.scatter(x=data_range[y], y=data_range['mri_q10'], marker='_', color='blue', s=13)
    ax1.scatter(x=data_range[y], y=data_range['mri_q90'], marker='_', color='blue', s=13)
    ax1.set_ylim([0, 4.0])
    ax1.set_xlim([0, 500])
    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.set_ylabel('Well Complexity Index Ranking')
    ax1.set_xticks([100, 200, 300, 400, 500])
    ax1.set_xlabel('Rig Type Ranking (Max. Drilling Depth)')

    # Titles (raw string: '\%' is a literal backslash-percent for LaTeX)
    ax0.set_title('(a) Average match for each rig type', pad=10)
    ax1.set_title(r'(b) Matching range (10\% to 90\%)', pad=10)

    plt.rcParams['text.latex.preamble'] = r"\usepackage{lmodern}"  # Use latex
    plt.rcParams['font.family'] = 'serif'
    plt.rcParams['font.serif'] = ['lmodern'] + plt.rcParams['font.serif']
    rc('text', usetex=True)

    f.tight_layout(pad=4)

    plt.savefig(
        output_path,
        bbox_inches='tight'
    )
    plt.close(f)


def graph_oil_vs_gas(df_prop, df_gas, output_path):
    """Two-panel figure comparing natural-gas and oil valuations.

    Panel (a): density histogram (50 bins) of ``df_prop['prop']``, the
    proportion of production that is natural gas. Panel (b):
    ``df_gas['quantity_all_gas']`` and ``df_gas['quantity_prop']`` plotted
    against ``df_gas['month']``.

    Inputs:
        df_prop (DataFrame): needs a 'prop' column (shown on [0, 1]).
        df_gas (DataFrame): needs 'month', 'quantity_all_gas' and
            'quantity_prop' columns.
        output_path (str): file the figure is saved to.
    """
    def _use_latex_fonts():
        # LaTeX text rendering with Latin Modern serif fonts.
        plt.rcParams['text.latex.preamble'] = r"\usepackage{lmodern}"
        plt.rcParams['font.family'] = 'serif'
        plt.rcParams['font.serif'] = ['lmodern'] + plt.rcParams['font.serif']
        rc('text', usetex=True)

    _use_latex_fonts()
    text_sizes = (
        ('font', {'size': 9}),
        ('axes', {'titlesize': 9}),
        ('axes', {'labelsize': 9}),
        ('xtick', {'labelsize': 9}),
        ('ytick', {'labelsize': 9}),
        ('legend', {'fontsize': 8}),
        ('figure', {'titlesize': 10}),
    )
    for group, options in text_sizes:
        plt.rc(group, **options)

    fig = plt.figure(figsize=(7, 4.8))

    # (a) distribution of the natural-gas share of production
    share_ax = plt.subplot(1, 2, 1)
    share_ax.hist(df_prop['prop'], density=True, bins=50, color='blue')
    share_ax.set_xlim([0, 1])
    for side in ('top', 'right'):
        share_ax.spines[side].set_visible(False)
    share_ax.set_ylabel('Density')
    share_ax.set_xticks([0, 0.5, 1.0])
    share_ax.set_xlabel('Proportion')

    # (b) value of one barrel-of-oil equivalent under the two pricing rules
    value_ax = plt.subplot(1, 2, 2)
    value_ax.plot(df_gas['month'], df_gas['quantity_all_gas'],
                  color='blue', label='Natural Gas Price Only')
    value_ax.plot(df_gas['month'], df_gas['quantity_prop'],
                  color='dodgerblue', label='Both Prices')
    for side in ('top', 'right'):
        value_ax.spines[side].set_visible(False)
    value_ax.set_ylabel('Value of 1 Equivalent Barrel of Oil')
    value_ax.set_xticks([
        pd.to_datetime(day) for day in ('2000-01-01', '2005-01-01', '2010-01-01')
    ])
    value_ax.set_xticklabels(['2000', '2005', '2010'])
    value_ax.set_yticks([0, 0.5, 1.0, 1.5, 2.0, 2.5])
    value_ax.set_xlabel('Year')
    value_ax.legend(frameon=False, loc='lower right')

    share_ax.set_title('(a) Proportion of Production - Natural Gas', pad=10)
    value_ax.set_title('(b) Differences in Value', pad=10)

    # The font settings are applied a second time before saving.
    _use_latex_fonts()

    fig.tight_layout(pad=4)

    plt.savefig(output_path, bbox_inches='tight')


def graph_boom_bust(ci, data, output_path):
    """Two-panel figure contrasting well complexity in booms and busts.

    Panel (a): mean 'mri' per rig spec (low/mid/high) in bust and boom
    periods. Panel (b): histograms of 'mri' in bust vs boom periods, each
    weighted by 1/n so bar heights are shares of that regime's observations.

    Inputs:
        ci (DataFrame): confidence intervals for the spec coefficients. It
            is currently unused (the 95% error bars are disabled) and is no
            longer modified; kept for backward compatibility.
        data (DataFrame): must contain 'spec' (values 'low', 'mid', 'high'),
            a boolean 'boom' indicator and 'mri'.
        output_path (str): file the figure is saved to.

    Outputs:
        None (saves the figure to output_path and closes it)
    """
    plt.rc('font', size=9)  # controls default text sizes
    plt.rc('axes', titlesize=11)  # fontsize of the axes title
    plt.rc('axes', labelsize=10)  # fontsize of the x and y labels
    plt.rc('xtick', labelsize=9)  # fontsize of the tick labels
    plt.rc('ytick', labelsize=9)  # fontsize of the tick labels
    plt.rc('legend', fontsize=9)  # legend fontsize
    plt.rc('figure', titlesize=11)  # figure title size

    f = plt.figure(figsize=(7, 4.2))

    # Left panel: mean complexity by rig spec, bust vs boom
    ax0 = plt.subplot(1, 2, 1)

    specs = ['low', 'mid', 'high']
    data_boom = data.groupby(['spec', 'boom'])['mri'].mean()
    y_bust = [data_boom.xs((spec, False)) for spec in specs]
    y_boom = [data_boom.xs((spec, True)) for spec in specs]
    x = [1, 2, 3]

    ax0.plot(x, y_bust, marker='.', color='blue', linewidth=1.3, label="Bust", markersize=8)
    ax0.plot(x, y_boom, marker='.', color='dodgerblue', linewidth=1.3, label="Boom", linestyle='--',
             markersize=8)
    ax0.set_ylabel('Well Complexity Index')
    ax0.set_xlabel('Rig Type')
    ax0.set_yticks([0.7, 0.8, 0.9, 1.0])
    ax0.set_ylim([0.68, 1.02])
    ax0.set_xlim([0.9, 3.1])
    ax0.set_xticks(x)
    ax0.set_xticklabels(['Low-spec', 'Mid-spec', 'High-spec'])
    ax0.spines['right'].set_visible(False)
    ax0.spines['top'].set_visible(False)
    # NOTE: the 95% confidence-interval error bars (from ci) are not drawn.
    ax0.legend(frameon=False)

    # Right panel: distribution of complexity, bust vs boom
    ax1 = plt.subplot(1, 2, 2)

    mri_bust = data[data['boom'] == 0]['mri']
    mri_boom = data[data['boom'] == 1]['mri']
    bins = np.linspace(0, 2.1, 30)
    weights_bust = np.ones_like(mri_bust) / float(len(mri_bust))
    weights_boom = np.ones_like(mri_boom) / float(len(mri_boom))
    ax1.hist([mri_bust, mri_boom], bins, label=['Bust: Complexity', 'Boom: Complexity'],
             weights=[weights_bust, weights_boom], color=['blue', 'silver'])
    ax1.spines['right'].set_visible(False)
    ax1.spines['top'].set_visible(False)
    ax1.set_ylim([0, 0.12])
    ax1.set_xlim([0, 2.1])
    ax1.set_xticks([0, 0.5, 1, 1.5, 2])
    ax1.set_yticks([0, 0.05, 0.1])
    ax1.set_xlabel('Well Complexity Index')
    ax1.set_ylabel("Probability")
    ax1.legend(frameon=False)

    # Titles
    ax0.set_title('(a) Average match for each rig type', pad=10)
    ax1.set_title('(b) Distribution of well complexity', pad=10)

    plt.rcParams['text.latex.preamble'] = r"\usepackage{lmodern}"  # Use latex
    plt.rcParams['font.family'] = 'serif'
    plt.rcParams['font.serif'] = ['lmodern'] + plt.rcParams['font.serif']
    rc('text', usetex=True)

    f.tight_layout(pad=4)
    plt.savefig(
        output_path,
        bbox_inches='tight'
    )
    plt.close(f)


def graph_boom_bust_composition(ci, data, output_path):
    """Plot the bust vs boom distribution of well complexity and save it.

    Draws a single-panel histogram of the ``mri`` column (labelled "Well
    Complexity Index"), split into rows with ``boom == 0`` (bust) and
    ``boom == 1`` (boom). Each group's bar heights are weighted by
    ``1 / len(group)`` so that each group sums to one.

    Args:
        ci: Unused. It is kept so the signature matches the other boom/bust
            graph functions (e.g. ``graph_boom_bust_2``).
        data (pd.DataFrame): must contain ``boom`` and ``mri`` columns.
        output_path (str): file the figure is saved to.

    Returns:
        None. The figure is saved to ``output_path`` and then closed.
    """
    plt.rc('font', size=9)  # controls default text sizes
    plt.rc('axes', titlesize=11)  # fontsize of the axes title
    plt.rc('axes', labelsize=10)  # fontsize of the x and y labels
    plt.rc('xtick', labelsize=9)  # fontsize of the tick labels
    plt.rc('ytick', labelsize=9)  # fontsize of the tick labels
    plt.rc('legend', fontsize=9)  # legend fontsize
    plt.rc('figure', titlesize=11)  # figure title size

    f = plt.figure(figsize=(4.2, 4.2))
    ax1 = plt.subplot(1, 1, 1)

    x = data[data['boom'] == 0]['mri']
    y = data[data['boom'] == 1]['mri']
    bins = np.linspace(0, 2.1, 30)
    # Normalise each group separately so bust and boom are comparable
    # even though they contain different numbers of contracts.
    weights_x = np.ones_like(x) / float(len(x))
    weights_y = np.ones_like(y) / float(len(y))
    ax1.hist([x, y], bins, label=['Bust: Complexity', 'Boom: Complexity'],
             weights=[weights_x, weights_y], color=['blue', 'silver'])
    ax1.spines['right'].set_visible(False)
    ax1.spines['top'].set_visible(False)
    ax1.set_ylim([0, 0.12])
    ax1.set_xlim([0, 2.1])
    ax1.set_xticks([0, 0.5, 1, 1.5, 2])
    ax1.set_yticks([0, 0.05, 0.1])
    ax1.set_xlabel('Well Complexity Index')
    ax1.set_ylabel("Probability")
    ax1.legend(frameon=False)

    # Use LaTeX with Latin Modern fonts. Prepend 'lmodern' only once, so
    # repeated calls do not keep growing the global font list.
    plt.rcParams['text.latex.preamble'] = r"\usepackage{lmodern}"
    plt.rcParams['font.family'] = 'serif'
    serif_fonts = list(plt.rcParams['font.serif'])
    if not serif_fonts or serif_fonts[0] != 'lmodern':
        plt.rcParams['font.serif'] = ['lmodern'] + serif_fonts
    rc('text', usetex=True)

    f.tight_layout(pad=2)
    plt.savefig(
        output_path,
        bbox_inches='tight'
    )
    # Release the figure; pyplot otherwise keeps every figure alive.
    plt.close(f)


def graph_boom_bust_2(ci, data, output_path):
    """Plot average well complexity by rig type in boom vs bust, and the gap.

    Panel (a) shows the mean ``mri`` for each rig spec, separately for bust
    rows (``boom`` False/0) and boom rows (``boom`` True/1). Panel (b) shows
    the boom-minus-bust difference for each spec. Its error bars have
    half-width ``(ci[1] - ci[0]) / 2``, read from the ``C(spec)[low|mid|high]``
    rows of ``ci``.

    Args:
        ci (pd.DataFrame): interval bounds with lower/upper values in columns
            ``0``/``1`` and rows labelled ``C(spec)[low]``, ``C(spec)[mid]``,
            ``C(spec)[high]`` (e.g. a statsmodels ``conf_int`` frame). It is
            no longer modified.
        data (pd.DataFrame): needs ``spec`` ('low'/'mid'/'high'), ``boom``
            and ``mri`` columns.
        output_path (str): file the figure is saved to.

    Returns:
        None. The figure is saved to ``output_path`` and then closed.
    """
    plt.rc('font', size=9)  # controls default text sizes
    plt.rc('axes', titlesize=11)  # fontsize of the axes title
    plt.rc('axes', labelsize=10)  # fontsize of the x and y labels
    plt.rc('xtick', labelsize=9)  # fontsize of the tick labels
    plt.rc('ytick', labelsize=9)  # fontsize of the tick labels
    plt.rc('legend', fontsize=9)  # legend fontsize
    plt.rc('figure', titlesize=11)  # figure title size

    specs = ['low', 'mid', 'high']
    spec_labels = ['Low-spec', 'Mid-spec', 'High-spec']
    x = [1, 2, 3]

    f = plt.figure(figsize=(7, 4.2))

    # Left panel: mean complexity by (spec, boom)
    ax0 = plt.subplot(1, 2, 1)
    data_boom = data.groupby(['spec', 'boom'])['mri'].mean()
    y_bust = [data_boom.xs((spec, False)) for spec in specs]
    y_boom = [data_boom.xs((spec, True)) for spec in specs]
    y_diff = [boom - bust for boom, bust in zip(y_boom, y_bust)]

    ax0.plot(x, y_bust, marker='.', color='blue', linewidth=1.3, label="Bust", markersize=8)
    ax0.plot(x, y_boom, marker='.', color='dodgerblue', linewidth=1.3, label="Boom", linestyle='--',
             markersize=8)
    ax0.set_ylabel('Well Complexity Index')
    ax0.set_xlabel('Rig Type')
    ax0.set_yticks([0.7, 0.8, 0.9, 1.0])
    ax0.set_ylim([0.68, 1.02])
    ax0.set_xlim([0.9, 3.1])
    ax0.set_xticks(x)
    ax0.set_xticklabels(spec_labels)
    ax0.spines['right'].set_visible(False)
    ax0.spines['top'].set_visible(False)
    ax0.legend(frameon=False)

    # Right panel: boom-minus-bust difference with interval half-widths.
    ax1 = plt.subplot(1, 2, 2)
    # Computed locally so the caller's `ci` frame is not modified.
    # NOTE(review): these rows are the C(spec)[...] coefficients. If `ci` comes
    # from 'mri ~ C(spec) + C(spec):C(boom) - 1', the interval for the plotted
    # difference is the C(spec)[...]:C(boom)[T.True] row instead. Confirm
    # which one is intended before changing it.
    half_width = (ci[1] - ci[0]) / 2
    yerr = [half_width.loc[f'C(spec)[{spec}]'] for spec in specs]
    ax1.errorbar(x, y_diff, yerr=yerr, fmt='o', color='blue', capsize=6.0,
                 ecolor='silver', elinewidth=3, markersize=1.1, alpha=1.0,
                 label=r'95\% Confidence Interval')
    # The legend is built before the line below is drawn, so it only lists
    # the interval entry.
    ax1.legend(frameon=False)
    ax1.plot(x, y_diff, marker='.', color='black', linewidth=1.3, markersize=8)
    ax1.axhline(0.0, color='gray', ls='--', linewidth=1.1)
    ax1.set_ylabel('Difference in Well Complexity Index')
    ax1.set_xlabel('Rig Type')
    ax1.set_yticks([-0.15, 0, 0.15])
    ax1.set_ylim([-0.17, 0.17])
    ax1.set_xlim([0.9, 3.1])
    ax1.set_xticks(x)
    ax1.set_xticklabels(spec_labels)
    ax1.spines['right'].set_visible(False)
    ax1.spines['top'].set_visible(False)

    # Titles
    ax0.set_title('(a) Average match for each rig type', pad=10)
    ax1.set_title('(b) Differences in average match in boom vs bust', pad=10)

    # Use LaTeX with Latin Modern fonts. Prepend 'lmodern' only once.
    plt.rcParams['text.latex.preamble'] = r"\usepackage{lmodern}"
    plt.rcParams['font.family'] = 'serif'
    serif_fonts = list(plt.rcParams['font.serif'])
    if not serif_fonts or serif_fonts[0] != 'lmodern':
        plt.rcParams['font.serif'] = ['lmodern'] + serif_fonts
    rc('text', usetex=True)

    f.tight_layout(pad=4)
    plt.savefig(
        output_path,
        bbox_inches='tight'
    )
    # Release the figure; pyplot otherwise keeps every figure alive.
    plt.close(f)


def build_table_summary(
        df_contracts,
        df_state,
        overleaf_path,
        input_path='./src/tex/table_summary.tex',
        output_path='./reports/tables/table_summary.tex'
):
    """Fill the summary-statistics table template and export a few numbers.

    Template keys produced (``stat`` in count/mean/std/10%/90%):

    * ``{stat}_{new|reneg}_{duration|dayrate}``: contract duration and
      dayrate, split by ``reneg == 0`` ('new') and ``reneg == 1`` ('reneg').
      Counts and durations are truncated with ``int``, and dayrates are
      multiplied by 1000 and truncated.
    * ``{stat}_{mri|value|waterd}``: computed over all contracts. ``value``
      is scaled by ``mean(df_state['gas']) * 30``.
    * ``{stat}_util``: computed over the pooled low/mid/high utilization
      columns.

    Args:
        df_contracts (pd.DataFrame): needs ``total_days_description``,
            ``day_rate``, ``water_depth``, ``reneg``, ``mri`` and ``value``.
        df_state (pd.DataFrame): needs ``gas`` and
            ``utilization_{low,mid,high}``.
        overleaf_path (str): prefix, concatenated directly, for the
            ``parameters/parameter_*.tex`` outputs.
        input_path (str): tex template using ``str.format`` placeholders.
        output_path (str): where the filled table is written.
    """
    summary_stats = ['count', 'mean', 'std', '10%', '90%']
    summary_to_tex = dict()
    mean_gas = df_state['gas'].mean()

    # rename() returns new frames, so the caller's data is untouched.
    df_contracts = df_contracts.rename(
        columns={
            'total_days_description': 'duration',
            'day_rate': 'dayrate',
            'water_depth': 'waterd'
        }
    )
    df_state = df_state.rename(
        columns={
            'utilization_high': 'util_high',
            'utilization_mid': 'util_mid',
            'utilization_low': 'util_low'
        }
    )

    # Duration / dayrate by contract type: index (reneg, stat), columns metric.
    df_summary_reneg = df_contracts.groupby(['reneg'])[[
        'duration',
        'dayrate'
    ]].describe(percentiles=[0.1, 0.9]).stack()

    for reneg_numeric, reneg_text in zip([0, 1], ['new', 'reneg']):
        for summary_stat, metric in itertools.product(
                summary_stats,
                ['duration', 'dayrate']
        ):
            tex_key = f'{summary_stat}_{reneg_text}_{metric}'
            # xs over every index level still returns a one-row frame (pandas
            # cannot drop all levels), so extract the scalar explicitly.
            # float() on a one-element Series is deprecated.
            value = float(
                df_summary_reneg.xs(
                    (reneg_numeric, summary_stat),
                    level=[0, 1]
                )[metric].iloc[0]
            )
            # int() truncates rather than rounds; kept as-is so published
            # numbers do not change.
            if summary_stat == 'count' or metric == 'duration':
                value = int(value)
            elif metric == 'dayrate':
                value = int(value * 1000)
            summary_to_tex[tex_key] = value

    # Other contract variables: Series indexed by (stat, metric).
    df_summary_all = df_contracts.describe(percentiles=[0.1, 0.9]).stack()
    for summary_stat, metric in itertools.product(
            summary_stats,
            ['mri', 'value', 'waterd']
    ):
        tex_key = f'{summary_stat}_{metric}'
        # Positional access via .iloc. Plain `[0]` relied on the deprecated
        # positional fallback of Series.__getitem__.
        value = df_summary_all.xs(
            (summary_stat, metric),
            level=[0, 1]
        ).iloc[0]

        if summary_stat == 'count':
            value = int(value)
        elif metric == 'mri':
            value = round(value, 2)
        elif metric == 'waterd':
            value = int(value)
        elif metric == 'value':
            scaled = value * mean_gas * 30
            if summary_stat == '10%':
                value = round(scaled, 2)
            elif summary_stat == 'mean':
                value = round(scaled, 1)
            else:
                value = int(scaled)
        summary_to_tex[tex_key] = value

    # Utilization, pooled over the three rig types.
    df_summary_state = df_state[
        ['util_low', 'util_mid', 'util_high']
    ].stack().describe(percentiles=[0.1, 0.9])

    for stat in summary_stats:
        value = df_summary_state[stat]
        summary_to_tex[f'{stat}_util'] = (
            int(value) if stat == 'count' else round(value, 2)
        )

    with open(input_path, 'r') as f:
        output = f.read().format(**summary_to_tex)

    with open(output_path, 'w') as f:
        f.write(output)

    # Save some numbers for referencing in the paper text.
    for name in (
            'mean_reneg_duration',
            'mean_new_duration',
            'std_reneg_duration',
            'std_new_duration',
    ):
        with open(overleaf_path + f'parameters/parameter_{name}.tex', 'w') as fout:
            fout.write(f"{summary_to_tex[name]}")


def build_table_transitions(transitions_params, transitions_errors,
                            input_path, output_path, overleaf_path):
    """Render the transitions table and write one tex file per entry.

    ``transitions_params`` and ``transitions_errors`` are merged into a
    single mapping. Where both define a key, the ``transitions_errors``
    value is used. The merged mapping fills the ``str.format``
    placeholders in the template at ``input_path``, and the result is
    written to ``output_path``. Each entry is also written on its own to
    ``overleaf_path + 'parameters/parameter_<key>.tex'``.
    """
    merged = dict(transitions_params)
    merged.update(transitions_errors)

    with open(input_path, 'r') as template:
        rendered = template.read().format(**merged)
    with open(output_path, 'w') as table:
        table.write(rendered)

    for key, value in merged.items():
        target = overleaf_path + f'parameters/parameter_{key}.tex'
        with open(target, 'w') as fout:
            fout.write(f"{value}")


def graph_acceptance_sets(cutoffs_min, cutoffs_max, output_path):
    """Plot the band between the complexity cutoffs over time for each spec.

    One subplot is drawn per rig spec ('low', 'mid', 'high'). Each shades
    the area between ``cutoffs_min`` and ``cutoffs_max`` and draws both
    bounds in black. Columns are looked up by the string key
    ``"('<spec>', <tau>)"``, and only ``tau = 2`` is plotted.

    Args:
        cutoffs_min, cutoffs_max (pd.DataFrame): lower and upper cutoffs.
            Each has columns keyed as above and an index used for the x axis.
        output_path (str): file the figure is saved to.

    Returns:
        None. The figure is saved to ``output_path`` and then closed.
    """
    f = plt.figure(figsize=(6, 4))

    titles = {
        'low': 'Low-spec',
        'mid': 'Mid-spec',
        'high': 'High-spec'
    }

    ax = dict()
    for k, spec in enumerate(['low', 'mid', 'high']):
        ax[spec] = f.add_subplot(1, 3, k + 1)
        for tau in [2]:
            column = f"('{spec}', {tau})"
            lower = cutoffs_min[column]
            upper = cutoffs_max[column]
            ax[spec].fill_between(
                x=lower.index,
                y1=lower,
                y2=upper,
                color='lightgray'
            )
            ax[spec].plot(lower, color='black')
            ax[spec].plot(upper, color='black')

        # Axis decoration does not depend on tau, so it is set once per panel.
        ax[spec].set_xlabel('Date')
        ax[spec].set_ylabel('Well Complexity')
        ax[spec].set_title(titles[spec])
        ax[spec].set_xticks(['2000-01-01', '2005-01-01', '2009-12-31'])
        ax[spec].set_xticklabels(['2000', '2005', '2010'])
        ax[spec].set_ylim([-0.1, 2.25])

    # Use LaTeX with Latin Modern fonts. Prepend 'lmodern' only once.
    plt.rcParams['text.latex.preamble'] = r"\usepackage{lmodern}"
    plt.rcParams['font.family'] = 'serif'
    serif_fonts = list(plt.rcParams['font.serif'])
    if not serif_fonts or serif_fonts[0] != 'lmodern':
        plt.rcParams['font.serif'] = ['lmodern'] + serif_fonts
    rc('text', usetex=True)

    f.tight_layout()
    f.savefig(output_path)
    # Release the figure; pyplot otherwise keeps every figure alive.
    plt.close(f)


def build_table_dispersion(df, input_path, output_path, overleaf_path):
    """Measure residual day-rate dispersion under several hedonic specs.

    Day rates are demeaned by month and regressed (OLS via statsmodels
    formulas) on contract, market and fixed-effect controls. There are six
    specifications: linear, quadratic and cubic polynomials, each with
    rig-spec (``agg*``) or ``max_wd`` (``disagg*``) interactions. For each
    specification ``name`` this produces:

    * ``r2_{name}``: ``1 - R^2``, rounded to 2 dp.
    * ``p_tilde_{name}``: ``np.std`` of the residuals times 1000, rounded
      to a whole number.
    * ``p_hat_{name}``: ``np.std`` of the demeaned day rate times 1000,
      rounded to a whole number (the same for every spec).

    Each value is written to ``overleaf_path + 'parameters/parameter_<key>.tex'``,
    and all of them fill the template at ``input_path``, whose output goes
    to ``output_path``.

    Args:
        df (pd.DataFrame): contract data with ``month``, ``day_rate``,
            ``mri``, ``value``, ``water_depth``, ``tau``, ``dist``, ``gas``,
            ``n_available_{low,mid,high}``, ``operator``, ``contractor``,
            ``spec`` and ``max_wd``. It is not modified.
        input_path (str): tex template with ``str.format`` placeholders.
        output_path (str): where the filled table is written.
        overleaf_path (str): prefix, concatenated directly, for parameter files.
    """
    # Work on a copy: the original converted the caller's 'month' column in place.
    df = df.copy()
    df['month'] = pd.to_datetime(df['month'])
    df_mean = df.groupby(['month'], as_index=False)['day_rate'].mean()
    df_mean.columns = ['month', 'day_rate_mean']
    df = df.merge(df_mean, on='month')
    df['day_rate_demean'] = df['day_rate'] - df['day_rate_mean']
    df['year'] = df['month'].dt.year
    df['month'] = df['month'].dt.month

    # Explicit polynomial columns for the quadratic / cubic specifications.
    poly_cols = ['mri', 'value', 'water_depth', 'tau', 'dist']
    for col in poly_cols:
        df[f'{col}_2'] = df[col] * df[col]
    for col in poly_cols:
        df[f'{col}_3'] = df[col] * df[col] * df[col]

    # NOTE(review): in patsy formulas a term interacted with itself collapses
    # to the term, so `gas:gas`, `gas:gas:gas`, `n_available_low:n_available_low`,
    # etc. add nothing beyond the linear terms. If squares or cubes were
    # intended, use I(gas**2) or explicit columns. Left unchanged so the
    # reported numbers stay reproducible.
    formulas = dict()
    formulas['agg'] = """
        day_rate_demean
            ~ mri
            + water_depth
            + value
            + gas
            + tau
            + dist
            + n_available_low
            + n_available_mid
            + n_available_high
            + C(operator)
            + C(contractor)
            + C(year)
            + C(month)
            + C(spec)
            + C(spec):mri
            + C(spec):value
            + C(spec):water_depth
            + C(spec):tau
            + C(spec):dist
    """
    # The original listed `C(max_wd):dist` twice; patsy de-duplicates terms,
    # so dropping the repeat does not change the fitted model.
    formulas['disagg'] = """
        day_rate_demean
            ~ mri
            + water_depth
            + value
            + gas
            + tau
            + dist
            + n_available_low
            + n_available_mid
            + n_available_high
            + C(operator)
            + C(contractor)
            + C(year)
            + C(month)
            + C(max_wd)
            + C(max_wd):dist
            + C(max_wd):mri
            + C(max_wd):value
            + C(max_wd):water_depth
            + C(max_wd):tau
    """
    formulas['agg_2'] = """
        day_rate_demean
            ~ mri
            + mri_2
            + water_depth
            + water_depth_2
            + value
            + value_2
            + tau
            + tau_2
            + dist
            + dist_2
            + gas
            + n_available_low
            + n_available_mid
            + n_available_high
            + gas:gas
            + n_available_low:n_available_low
            + n_available_mid:n_available_mid
            + n_available_high:n_available_high
            + C(operator)
            + C(contractor)
            + C(year)
            + C(month)
            + C(spec)
            + C(spec):mri
            + C(spec):value
            + C(spec):water_depth
            + C(spec):tau
            + C(spec):dist
            + C(spec):mri_2
            + C(spec):value_2
            + C(spec):water_depth_2
            + C(spec):tau_2
            + C(spec):dist_2
    """
    formulas['disagg_2'] = """
        day_rate_demean
            ~ mri
            + mri_2
            + water_depth
            + water_depth_2
            + value
            + value_2
            + tau
            + tau_2
            + dist
            + dist_2
            + gas
            + n_available_low
            + n_available_mid
            + n_available_high
            + gas:gas
            + n_available_low:n_available_low
            + n_available_mid:n_available_mid
            + n_available_high:n_available_high
            + C(operator)
            + C(contractor)
            + C(year)
            + C(month)
            + C(max_wd)
            + C(max_wd):mri
            + C(max_wd):value
            + C(max_wd):water_depth
            + C(max_wd):tau
            + C(max_wd):dist
            + C(max_wd):mri_2
            + C(max_wd):value_2
            + C(max_wd):water_depth_2
            + C(max_wd):tau_2
            + C(max_wd):dist_2
    """
    formulas['agg_3'] = """
        day_rate_demean
            ~ mri
            + mri_2
            + mri_3
            + water_depth
            + water_depth_2
            + water_depth_3
            + value
            + value_2
            + value_3
            + tau
            + tau_2
            + tau_3
            + dist
            + dist_2
            + dist_3
            + gas
            + n_available_low
            + n_available_mid
            + n_available_high
            + gas:gas
            + n_available_low:n_available_low
            + n_available_mid:n_available_mid
            + n_available_high:n_available_high
            + gas:gas:gas
            + n_available_low:n_available_low:n_available_low
            + n_available_mid:n_available_mid:n_available_mid
            + n_available_high:n_available_high:n_available_high
            + C(operator)
            + C(contractor)
            + C(year)
            + C(month)
            + C(spec)
            + C(spec):mri
            + C(spec):value
            + C(spec):water_depth
            + C(spec):tau
            + C(spec):dist
            + C(spec):mri_2
            + C(spec):value_2
            + C(spec):water_depth_2
            + C(spec):tau_2
            + C(spec):dist_2
            + C(spec):mri_3
            + C(spec):value_3
            + C(spec):water_depth_3
            + C(spec):tau_3
            + C(spec):dist_3
    """
    formulas['disagg_3'] = """
        day_rate_demean
            ~ mri
            + mri_2
            + mri_3
            + water_depth
            + water_depth_2
            + water_depth_3
            + value
            + value_2
            + value_3
            + tau
            + tau_2
            + tau_3
            + dist
            + dist_2
            + dist_3
            + gas
            + n_available_low
            + n_available_mid
            + n_available_high
            + gas:gas
            + n_available_low:n_available_low
            + n_available_mid:n_available_mid
            + n_available_high:n_available_high
            + gas:gas:gas
            + n_available_low:n_available_low:n_available_low
            + n_available_mid:n_available_mid:n_available_mid
            + n_available_high:n_available_high:n_available_high
            + C(operator)
            + C(contractor)
            + C(year)
            + C(month)
            + C(max_wd)
            + C(max_wd):mri
            + C(max_wd):value
            + C(max_wd):water_depth
            + C(max_wd):tau
            + C(max_wd):dist
            + C(max_wd):mri_2
            + C(max_wd):value_2
            + C(max_wd):water_depth_2
            + C(max_wd):tau_2
            + C(max_wd):dist_2
            + C(max_wd):mri_3
            + C(max_wd):value_3
            + C(max_wd):water_depth_3
            + C(max_wd):tau_3
            + C(max_wd):dist_3
    """

    # Raw dispersion does not depend on the specification, so compute it once.
    # round() without ndigits rounds to a whole number.
    p_hat = round(np.std(df['day_rate_demean'] * 1000, 0))

    tex_input = dict()
    for name, formula in formulas.items():
        results = smf.ols(formula=formula, data=df).fit()
        # Stored under 'r2_' but holds the unexplained share, 1 - R^2.
        tex_input['r2_' + name] = round(1 - results.rsquared, 2)
        tex_input['p_tilde_' + name] = round(np.std(results.resid * 1000, 0))
        tex_input['p_hat_' + name] = p_hat

    for key, value in tex_input.items():
        with open(overleaf_path + f'parameters/parameter_{key}.tex',
                  'w') as fout:
            fout.write(f"{value}")

    with open(input_path, 'r') as f:
        output = f.read().format(**tex_input)
    with open(output_path, 'w') as f:
        f.write(output)


def build_table_match(params_match, ci_lower_match, ci_upper_match, delta,
                      av_price, input_path, output_path, overleaf_path):
    """Fill the match-value table template and write per-parameter tex files.

    For each match parameter ``i`` in ``params_match`` the template receives:

    * ``param_{i}``: the estimate times 1000, rounded to 1 dp.
    * ``param_{i}_prop``: the estimate divided by ``av_price``, rounded to
      2 dp (3 dp for ``m_1_mid``).
    * ``ci_lower_{i}``, ``ci_upper_{i}``: the bounds on the same two scales,
      with the ``_prop`` versions suffixed ``_prop``. These are ``None``
      when ``ci_lower_match`` is ``None``.

    ``param_delta`` is ``delta`` rounded to 2 dp. After the table is written,
    ``param_pass_through`` (``30 * param_m_2 / 1000``, 3 dp) and
    ``param_av_price`` (``av_price * 1000`` as an int) are added. Every
    parameter, plus ``av_price`` and ``pass_through``, is then written to
    ``overleaf_path + 'parameters/<name>.tex'`` and ``<name>_prop.tex``.

    Args:
        params_match (dict): parameter name -> estimate. Must contain
            ``m_2``. It is no longer modified.
        ci_lower_match, ci_upper_match: mappings (dict or pandas Series)
            keyed like ``params_match``, or ``None`` when no intervals exist.
        delta (float): written as ``param_delta``.
        av_price (float): divisor for the proportional (``_prop``) versions.
        input_path (str): tex template with ``str.format`` placeholders.
        output_path (str): where the filled table is written.
        overleaf_path (str): prefix, concatenated directly, for parameter files.
    """
    tex_input = dict()
    for i in params_match:
        # m_1_mid is reported with one extra decimal in proportional terms.
        prop_digits = 3 if i == 'm_1_mid' else 2
        tex_input[f'param_{i}'] = round(params_match[i] * 1000, 1)
        tex_input[f'param_{i}_prop'] = round(params_match[i] / av_price, prop_digits)

        # `is None`: `== None` on a pandas Series is element-wise, and using
        # that result in `if` raises ValueError.
        if ci_lower_match is None:
            tex_input[f'ci_lower_{i}'] = None
            tex_input[f'ci_upper_{i}'] = None
            tex_input[f'ci_lower_{i}_prop'] = None
            tex_input[f'ci_upper_{i}_prop'] = None
        else:
            tex_input[f'ci_lower_{i}'] = round(ci_lower_match[i] * 1000, 1)
            tex_input[f'ci_upper_{i}'] = round(ci_upper_match[i] * 1000, 1)
            tex_input[f'ci_lower_{i}_prop'] = round(ci_lower_match[i] / av_price, prop_digits)
            tex_input[f'ci_upper_{i}_prop'] = round(ci_upper_match[i] / av_price, prop_digits)
    tex_input['param_delta'] = round(delta, 2)

    with open(input_path, 'r') as f:
        output = f.read().format(**tex_input)
    with open(output_path, 'w') as f:
        f.write(output)

    # Derived parameters that are exported but not used in the table.
    tex_input['param_pass_through'] = round(30 * tex_input['param_m_2'] / 1000, 3)
    tex_input['param_pass_through_prop'] = 1.0
    tex_input['param_av_price'] = int(round(av_price * 1000, 0))
    tex_input['param_av_price_prop'] = 1.0

    # Export order matches the original: input keys first, then the derived
    # ones. Built locally so the caller's dict is not modified.
    names = list(dict.fromkeys([*params_match, 'av_price', 'pass_through']))
    for i in names:
        for k in ['', '_prop']:
            item = tex_input[f'param_{i}{k}']
            param_name = overleaf_path + f"parameters/{i}{k}.tex"
            with open(param_name, "wb") as f:
                f.write((str(item)).encode())


def build_table_sorting(df, input_path, output_path):
    """Run the sorting regressions and fill the sorting-table template.

    Three OLS regressions of ``mri`` on ``C(spec)`` and ``C(spec):C(boom)``
    are fitted without an intercept and with HC0 robust standard errors.
    Regression 1 adds ``value``, and regression 2 adds ``value`` and ``tau``.
    For regression ``i`` the template receives ``r_{i}`` (R^2, 2 dp) and
    ``n_{i}`` (observations). Every coefficient in ``coefs_to_tex_names``
    that the regression contains also gets ``{name}_{i}`` (estimate, 3 dp,
    with ``*``/``**``/``***`` for p < 0.1/0.05/0.01) and ``se_{name}_{i}``.

    Args:
        df (pd.DataFrame): needs ``mri``, ``spec``, ``boom``, ``value`` and
            ``tau``.
        input_path (str): tex template with ``str.format`` placeholders.
        output_path (str): where the filled table is written.

    Returns:
        dict: regression index -> 95% confidence-interval DataFrame from
        ``conf_int``.
    """
    formulas_by_reg = {
        0: 'mri ~ C(spec) + C(spec):C(boom) - 1',
        1: "mri ~ C(spec) + C(spec):C(boom) + value -1",
        2: "mri ~ C(spec) + C(spec):C(boom) + value + tau -1",
    }

    # NOTE(review): the formulas use plain `value`, so the 'gas:value' key
    # never matches a coefficient and '{value}_{i}' is never filled. Likewise,
    # 'reneg' is not in any formula. Confirm against the template.
    coefs_to_tex_names = {
        'C(spec)[low]': 'low',
        'C(spec)[mid]': 'mid',
        'C(spec)[high]': 'high',
        'C(spec)[low]:C(boom)[T.True]': 'low_boom',
        'C(spec)[mid]:C(boom)[T.True]': 'mid_boom',
        'C(spec)[high]:C(boom)[T.True]': 'high_boom',
        'gas:value': 'value',
        'tau': 'total_days_description',
        'reneg': 'reneg'
    }

    summary_to_tex = dict()
    ci_by_reg = dict()
    for i, formula in formulas_by_reg.items():
        reg = smf.ols(formula=formula, data=df).fit(cov_type='HC0')
        summary_to_tex[f"r_{i}"] = round(reg.rsquared, 2)
        summary_to_tex[f"n_{i}"] = int(reg.nobs)
        ci_by_reg[i] = reg.conf_int(alpha=0.05, cols=None)

        for coef, tex_name in coefs_to_tex_names.items():
            try:
                estimate = reg.params[coef]
                std_err = reg.HC0_se[coef]
                p_value = reg.pvalues[coef]
            except KeyError:
                # This coefficient is not part of specification `i`. Other
                # errors are no longer silently swallowed.
                continue

            # A NaN p-value gets no stars, as before.
            if p_value < 0.01:
                stars = "***"
            elif p_value < 0.05:
                stars = "**"
            elif p_value < 0.1:
                stars = "*"
            else:
                stars = ""

            summary_to_tex[f"{tex_name}_{i}"] = str(round(estimate, 3)) + stars
            summary_to_tex[f"se_{tex_name}_{i}"] = str(round(std_err, 3))

    with open(input_path, 'r') as f:
        output = f.read().format(**summary_to_tex)

    with open(output_path, 'w') as f:
        f.write(output)

    return ci_by_reg


def smarter_round(sig):
    """Return a function that rounds numbers to ``sig`` significant figures.

    The returned ``rounder(x)`` first rounds ``x`` to ``sig + 1`` significant
    figures. If that leaves ``x`` unchanged and the last printed digit is
    ``'5'``, the result keeps ``sig - 1`` significant figures. Otherwise it
    keeps ``sig``. Zero and non-finite values (nan, +/-inf) have no leading
    digit to anchor on, so they are returned unchanged. Previously
    ``log10`` raised on them.

    Args:
        sig (int): number of significant figures to keep.

    Returns:
        Callable[[float], float]: the rounding function.

    >>> smarter_round(2)(123.456)
    120.0
    """
    import math

    def rounder(x):
        if x == 0 or not math.isfinite(x):
            return x
        # Decimal places that leave sig + 1 significant digits.
        offset = sig - floor(log10(abs(x)))
        initial_result = round(x, offset)
        if str(initial_result)[-1] == '5' and initial_result == x:
            return round(x, offset - 2)
        return round(x, offset - 1)
    return rounder


def get_parameters(df_wells_merged, overleaf_path):
    """Export the mean and 75th-percentile well drilling durations.

    Adds a ``well_duration`` column to ``df_wells_merged`` in place. It holds
    the whole days between ``spud_date`` and ``depth_date``. The mean and
    the 0.75 quantile of that column, each rounded to 1 dp, are written to
    ``overleaf_path + 'parameters/parameter_well_drilling_duration_{mean,75}.tex'``.
    """
    elapsed = df_wells_merged['depth_date'] - df_wells_merged['spud_date']
    df_wells_merged['well_duration'] = elapsed.dt.days
    durations = df_wells_merged['well_duration']

    exports = {
        'well_drilling_duration_mean': round(durations.mean(), 1),
        'well_drilling_duration_75': round(durations.quantile(0.75), 1),
    }
    for suffix, number in exports.items():
        target = overleaf_path + f'parameters/parameter_{suffix}.tex'
        with open(target, 'w') as fout:
            fout.write(f"{number}")


def get_targeting_explanation(params, df_state, match_values_by_tau_spec,
                              surplus_grid, shares, overleaf_path):
    """ Compute the search-targeting illustration and write it to tex.

    Builds surplus interpolants with ``surplus.build_fast_surplus``. Under
    the 2005-01-01 state it evaluates the ``(3, spec)`` surplus of a well
    with complexity 2.0 for each spec and discounts it by
    ``(1 - params['delta'])``. Each value is weighted by the matching entry
    of ``model_objects.q_x``, and the result is passed to
    ``model_objects.get_mnl`` twice: with ``gamma=0`` (stored as
    'random_search') and with ``params['gamma']`` (stored as
    'partially_directed_search'). Each stored value is element ``[2]`` of
    the first output of ``get_mnl``, rounded to 2 dp.

    Args:
        params: parameter mapping. Reads 'delta', 'a_1_low', 'a_1_mid',
            'a_1_high', 'd_0', 'd_1', 'gamma' and 'gamma_negative', and is
            also passed to ``build_fast_surplus``. NOTE: mutated in place;
            'a_0' and 'a_1' are (over)written with numpy arrays.
        df_state: frame with a 'date' column and 'g', 'n_l', 'n_m', 'n_h'
            columns. Only the first row dated '2005-01-01' is used.
        match_values_by_tau_spec: passed through to
            ``surplus.build_fast_surplus``.
        surplus_grid: grid passed to ``build_fast_surplus`` and
            ``eval_linear``.
        shares: indexed by date; the '2005-01-01' row is read at 'low',
            'mid' and 'high'.
        overleaf_path: prefix, concatenated directly, for
            'tables/table_search.tex' and
            'parameters/parameter_tech_<key>.tex'.

    Returns:
        None. The table template is read from './src/tex/table_search.tex'
        relative to the working directory.
    """
    # Build the surplus values to interpolate
    # (the outside-option output is not used here).
    (
        surplus_values_by_tau_spec,
        well_outside_option_by_tau_spec
    ) = surplus.build_fast_surplus(
        match_values_by_tau_spec,
        params,
        surplus_grid
    )

    state = df_state.loc[df_state['date'] == '2005-01-01']
    # Interpolation point: well complexity 2.0 followed by the state
    # variables (g, n_l, n_m, n_h).
    mri_with_state = np.array(
        [
            2.0,
            state['g'].iloc[0],
            state['n_l'].iloc[0],
            state['n_m'].iloc[0],
            state['n_h'].iloc[0],
        ]
    )
    # `state` is rebound from the filtered DataFrame to the bare state vector.
    state = np.array([
            state['g'].iloc[0],
            state['n_l'].iloc[0],
            state['n_m'].iloc[0],
            state['n_h'].iloc[0],
    ])

    surp_by_spec = dict()
    for spec in ['low', 'mid', 'high']:
        # Interpolated surplus from the (3, spec) entry at the point above.
        surp_by_spec[spec] = eval_linear(
            surplus_grid,
            surplus_values_by_tau_spec[(3, spec)],
            mri_with_state)

        surp_by_spec[spec] = (1 - params['delta']) * surp_by_spec[spec]

    # Get the prob of a well matching
    shares_2005 = np.array([
        shares.loc['2005-01-01']['low'],
        shares.loc['2005-01-01']['mid'],
        shares.loc['2005-01-01']['high']
    ])
    # Side effect: these entries stay in the caller's `params`.
    params['a_0'] = np.array([1.0, 1.0, 1.0])
    params['a_1'] = np.array([
        params['a_1_low'],
        params['a_1_mid'],
        params['a_1_high']
    ])
    q_x_2005 = model_objects.q_x(
        params['a_0'], params['a_1'], params['d_0'], params['d_1'],
        state, shares_2005)

    # Needs to be in the below form...
    # (3 rows, one per spec, transposed. If eval_linear returns a scalar for a
    # single point, the shape is (1, 3); presumably what get_mnl expects,
    # TODO confirm.)
    ev = np.array([
        [q_x_2005[0] * surp_by_spec['low']],
        [q_x_2005[1] * surp_by_spec['mid']],
        [q_x_2005[2] * surp_by_spec['high']],
    ]).T

    tex_input = dict()
    # gamma=0 run; element [2] presumably corresponds to the high-spec
    # alternative (ordering low/mid/high as in `ev`). Verify against get_mnl.
    a = model_objects.get_mnl(
        gamma=0,
        gamma_negative=params['gamma_negative'],
        ev=ev,
        state=state
    )[0]
    tex_input['random_search'] = round(a[2], 2)
    # Same calculation with the estimated gamma.
    a = model_objects.get_mnl(
        gamma=params['gamma'],
        gamma_negative=params['gamma_negative'],
        ev=ev,
        state=state
    )[0]
    tex_input['partially_directed_search'] = round(a[2], 2)

    # PRODUCE THE TABLE -----------------------------------------------------------------
    with open('./src/tex/table_search.tex', 'r') as f:
        tex = f.read()
        output = tex.format(**tex_input)

    with open(overleaf_path + './tables/table_search.tex', 'w') as f:
        f.write(output)

    # One file per value, for inline referencing.
    for p in tex_input:
        with open(overleaf_path + f'./parameters/parameter_tech_{p}.tex', 'w') as f:
            f.write(str(tex_input[p]))


def _mismatch_significance_stars(p_value):
    """Return significance stars for ``p_value``.

    ``***`` for p < 0.01, ``**`` for 0.01 <= p < 0.05, ``*`` for
    0.05 <= p < 0.1, and an empty string otherwise (including NaN).
    """
    if p_value < 0.01:
        return "***"
    if p_value < 0.05:
        return "**"
    if p_value < 0.1:
        return "*"
    return ""


def make_table_mismatch(df_contracts, df_state, params, overleaf_path):
    """Compare observed rig-well matches with a surplus-maximising reallocation.

    Each contract's match value for a rig of type ``spec`` is
    ``(m_0_spec + m_1_spec * mri) * tau_cut * 30``, where ``tau_cut`` is the
    contract duration in months (rounded, floored at 2). For every month the
    wells contracted that month are reassigned to the rigs available that
    month with ``scipy.optimize.linear_sum_assignment`` so that the total match
    value is maximised.

    Rigs available per type are the state's unemployed rigs plus the rigs
    contracted that month. This count is then adjusted by the difference
    between reallocated and observed matches from earlier months that are
    still running. Welch t-tests (one-sided, ``alternative='greater'``) then
    compare reallocated and observed match values in bust and boom months,
    and the bust-minus-boom gain.

    Args:
        df_contracts (pd.DataFrame): contracts with at least ``duration``,
            ``mri``, ``month``, ``spec`` and ``boom`` columns. Modified in
            place: ``id``, ``tau_cut``, ``match_value_{low,mid,high}``,
            ``match_value_reallocate``, ``negative``, ``rig_new``,
            ``match_value_empirical`` and ``diff_match`` are added.
        df_state (pd.DataFrame): monthly state with a ``date`` column
            matching ``df_contracts['month']`` and
            ``n_unemployed_{low,mid,high}`` columns.
        params (dict): must contain ``m_0_{spec}`` and ``m_1_{spec}`` for
            ``spec`` in low/mid/high.
        overleaf_path (str): output prefix; its ``tables/`` and
            ``parameters/`` sub-directories must already exist.

    Raises:
        ValueError: if a contract month does not have exactly one
            ``df_state`` row, or if fewer rigs than wells are available in a
            month.

    Outputs:
        Writes ``tables/table_mismatch.tex`` (filled from
        ``./src/tex/table_mismatch.tex``) and ``parameters/diff_*.tex``.
    """
    specs = ['low', 'mid', 'high']
    # Number of monthly slots tracked for re-allocated matches still running.
    horizon = 60

    match_value = {
        f'm_{k}_{spec}': params[f'm_{k}_{spec}']
        for spec in specs
        for k in (0, 1)
    }

    df_contracts['id'] = range(len(df_contracts))

    # Contract length in months; the model has a minimum duration of 2 months.
    df_contracts['tau_cut'] = df_contracts['duration'].apply(lambda x: int(np.round(x / 30.0)))
    df_contracts.loc[df_contracts['tau_cut'].isin([0, 1]), 'tau_cut'] = 2

    for spec in specs:
        df_contracts[f'match_value_{spec}'] = (
            match_value[f'm_0_{spec}'] + match_value[f'm_1_{spec}'] * df_contracts['mri']
        ) * df_contracts['tau_cut'] * 30

    df_contracts['match_value_reallocate'] = 0.0
    df_contracts['negative'] = False
    # Object dtype: the column later holds spec strings.
    df_contracts['rig_new'] = pd.Series(np.nan, index=df_contracts.index, dtype=object)

    # change_in_matches_by_spec[spec][k]: net number of extra rigs of `spec`
    # (re-allocated minus observed) still under contract k months ahead.
    # Slot 0 covers matches with 0 and 1 periods left, since a rig with one
    # period left is already available again.
    change_in_matches_by_spec = {spec: [0.0] * horizon for spec in specs}

    for month in df_contracts['month'].sort_values().unique():
        state_i = df_state[df_state['date'] == month]
        if len(state_i) != 1:
            raise ValueError(
                f"Expected exactly one df_state row for {month}, found {len(state_i)}")
        in_month = df_contracts['month'] == month
        df_contracts_i = df_contracts[in_month]

        n_by_spec_init = dict()
        n_by_spec = dict()
        cost_by_spec = dict()
        for spec in specs:
            n_unemployed = int(round(state_i[f'n_unemployed_{spec}'].iloc[0], 0))
            # Rigs available empirically: idle ones plus those contracted now.
            n_empirical = n_unemployed + int((df_contracts_i['spec'] == spec).sum())
            # Remove rigs still tied up (or freed) by earlier re-allocations.
            n_by_spec_init[spec] = n_empirical - int(np.sum(change_in_matches_by_spec[spec]))
            n_by_spec[spec] = max(0, n_by_spec_init[spec])

            # Each row is a rig, each column is a well (cost = -match value).
            cost_by_spec[spec] = np.tile(-df_contracts_i[f'match_value_{spec}'].values,
                                         (n_by_spec[spec], 1))

        cost_all = np.vstack([cost_by_spec[spec] for spec in specs])

        total_rigs = sum(n_by_spec.values())
        total_wells = len(df_contracts_i)
        if total_rigs < total_wells:
            raise ValueError(
                f"Not enough rigs ({total_rigs}) to match all wells ({total_wells}) on {month}")

        row_ind, col_ind = linear_sum_assignment(cost_all)
        # With at least as many rigs (rows) as wells (columns) every well is
        # assigned exactly once; order the values by well so they line up
        # with the rows of df_contracts_i.
        assigned_value = -cost_all[row_ind, col_ind]
        df_contracts.loc[in_month, 'match_value_reallocate'] = assigned_value[np.argsort(col_ind)]

        # Flag months where the unbounded rig count would have been negative.
        if any(value < 0 for value in n_by_spec_init.values()):
            df_contracts.loc[in_month, 'negative'] = True

        # Recover which rig type each well is now matched to.
        for spec in specs:
            same_value = in_month & (
                df_contracts['match_value_reallocate'] == df_contracts[f'match_value_{spec}'])
            df_contracts.loc[same_value, 'rig_new'] = spec

        # Record the net change in matched rigs of each type by contract length.
        # Only tau_cut values 1..horizon-1 are tracked.
        current_period = df_contracts.loc[in_month]
        for spec in specs:
            assigned_to_spec = current_period['rig_new'] == spec
            observed_spec = current_period['spec'] == spec
            for tau in range(1, horizon):
                has_tau = current_period['tau_cut'] == tau
                delta = (int((assigned_to_spec & has_tau).sum())
                         - int((observed_spec & has_tau).sum()))
                change_in_matches_by_spec[spec][tau - 1] += delta

        # Advance one month.
        for spec in specs:
            change_in_matches_by_spec[spec] = change_in_matches_by_spec[spec][1:] + [0]

    # GET THE DIFFERENCES INCLUDING THE MATCH VALUE -------------------------------------
    # The mri of each row is unchanged; only the rig type differs.
    for spec in specs:
        is_spec = df_contracts['spec'] == spec
        mri_empirical = df_contracts.loc[is_spec, 'mri']
        df_contracts.loc[is_spec, 'match_value_empirical'] = \
            (match_value[f"m_0_{spec}"] + match_value[f"m_1_{spec}"] * mri_empirical) * \
            df_contracts['tau_cut'] * 30

    df_contracts['diff_match'] = df_contracts['match_value_reallocate'] - df_contracts['match_value_empirical']

    # DO THE TESTS ----------------------------------------------------------------------
    bust_rows = df_contracts['boom'] == False  # noqa: E712 (NaN is neither boom nor bust)
    boom_rows = df_contracts['boom'] == True  # noqa: E712

    test_by_type = dict()
    test_by_type['diff'] = ttest_ind(
        df_contracts.loc[bust_rows, 'diff_match'],
        df_contracts.loc[boom_rows, 'diff_match'],
        equal_var=False,
        alternative='greater'
    )
    diff_coef = (
        df_contracts.loc[bust_rows, 'diff_match'].mean()
        - df_contracts.loc[boom_rows, 'diff_match'].mean()
    )

    test_by_type['bust'] = ttest_ind(
        df_contracts.loc[bust_rows, 'match_value_reallocate'],
        df_contracts.loc[bust_rows, 'match_value_empirical'],
        equal_var=False,
        alternative='greater'
    )
    bust_coef = (
        df_contracts.loc[bust_rows, 'match_value_reallocate'].mean()
        - df_contracts.loc[bust_rows, 'match_value_empirical'].mean()
    )

    test_by_type['boom'] = ttest_ind(
        df_contracts.loc[boom_rows, 'match_value_reallocate'],
        df_contracts.loc[boom_rows, 'match_value_empirical'],
        equal_var=False,
        alternative='greater'
    )
    boom_coef = (
        df_contracts.loc[boom_rows, 'match_value_reallocate'].mean()
        - df_contracts.loc[boom_rows, 'match_value_empirical'].mean()
    )

    # MAKE THE COEFS --------------------------------------------------------------------
    summary_to_tex = {
        'diff_bust': round(bust_coef, 3),
        'diff_boom': round(boom_coef, 3),
        'diff_diff': round(diff_coef, 3)
    }
    for k in ['diff', 'bust', 'boom']:
        p_value = test_by_type[k].pvalue
        summary_to_tex[f"p_{k}"] = str(round(p_value, 3)) + _mismatch_significance_stars(p_value)

    # PRODUCE THE TABLE -----------------------------------------------------------------
    with open('./src/tex/table_mismatch.tex', 'r') as f:
        output = f.read().format(**summary_to_tex)

    with open(overleaf_path + 'tables/table_mismatch.tex', 'w') as f:
        f.write(output)

    # GET THE PARAMETERS ----------------------------------------------------------------
    for k in ['diff_bust', 'diff_boom', 'diff_diff']:
        with open(overleaf_path + f'parameters/{k}.tex', 'w') as fout:
            fout.write(f"{summary_to_tex[k]}")


def get_duration_discussion(df_contracts, overleaf_path):
    """Write how often renegotiated contracts keep the same ``tau``.

    Each row is compared with the row immediately before it in
    ``df_contracts``'s current order (no sorting or grouping is applied).
    A renegotiation (``reneg == 1``) "keeps tau" when its ``tau`` equals the
    previous row's ``tau``.

    Three parameter files are written under ``{overleaf_path}parameters/``.
    Each holds a percentage rounded to one decimal and followed by a LaTeX
    ``\\%``:
        - ``parameter_pct_no_tau_change_from_reneg.tex``: the share of
          renegotiations that keep tau.
        - ``parameter_pct_tau_change_all.tex``: renegotiations that change
          tau, as a share of all contracts.
        - ``parameter_pct_no_tau_change_all.tex``: the complement of the
          previous share.

    ``df_contracts`` is not modified.
    """
    # tau minus the previous row's tau (NaN for the first row).
    diff_tau = df_contracts['tau'].diff()
    is_reneg = df_contracts['reneg'] == 1

    n_reneg = is_reneg.sum()
    n_no_tau_change = ((diff_tau == 0.0) & is_reneg).sum()
    pct_no_tau_change_from_reneg = n_no_tau_change / n_reneg
    pct_tau_change_all = (n_reneg - n_no_tau_change) / len(df_contracts)

    # WRITE OUTPUT ------------------------------------------------------------------------
    shares = {
        'pct_no_tau_change_from_reneg': pct_no_tau_change_from_reneg,
        'pct_tau_change_all': pct_tau_change_all,
        'pct_no_tau_change_all': 1 - pct_tau_change_all,
    }
    for name, share in shares.items():
        with open(overleaf_path + f'parameters/parameter_{name}.tex', 'w') as fout:
            # "\\%" writes a literal LaTeX-escaped percent sign.
            fout.write(f"{round(share * 100, 1)}\\%")


def get_data_section_parameters(df_contracts, overleaf_path):
    """Write concentration and turnkey statistics for the data section.

    The following files are written under ``{overleaf_path}parameters/``:
        - ``parameter_hhi_contractor.tex`` / ``parameter_hhi_operator.tex``:
          the Herfindahl-Hirschman index (0-10000, rounded to an int) of
          contract shares by contractor / operator. Shares use all rows as
          the denominator.
        - ``parameter_turnkey.tex``: the percentage of contracts that are
          turnkey (mean of the ``turnkey`` column x 100, one decimal).
        - ``parameter_adti.tex``: among turnkey contracts
          (``turnkey == 1.0``), the percentage whose operator is ``'ADTI'``.
    """
    n_contracts = len(df_contracts)

    # Get HHIs
    for i in ['contractor', 'operator']:
        # value_counts() replaces groupby(i).apply(len), which pandas now
        # deprecates because apply() operates on the grouping column.
        # Dividing by len(df) keeps rows with a missing key in the
        # denominator, matching the original computation.
        shares = df_contracts[i].value_counts() / n_contracts
        hhi = (shares ** 2).sum() * 10000

        with open(overleaf_path + f'parameters/parameter_hhi_{i}.tex',
                  'w') as fout:
            fout.write(f"{int(round(hhi, 0))}")

    # Get the turnkey proportion
    prop_turnkey = df_contracts['turnkey'].mean()
    is_turnkey = df_contracts['turnkey'] == 1.0
    n_turnkey_adti = int((is_turnkey & (df_contracts['operator'] == 'ADTI')).sum())
    prop_adti = n_turnkey_adti / int(is_turnkey.sum())

    with open(overleaf_path + 'parameters/parameter_turnkey.tex',
              'w') as fout:
        fout.write(f"{round(prop_turnkey * 100, 1)}")

    with open(overleaf_path + 'parameters/parameter_adti.tex',
              'w') as fout:
        fout.write(f"{round(prop_adti * 100, 1)}")


def get_relationships(df_contracts, overleaf_path):
    """Write the share of new mutual contracts with no recent rig-operator history.

    A contract has a past relationship if the same (``rig_name``,
    ``operator``) pair has at least one other contract whose
    ``fixture_date`` lies between two years (730 days) and one day before
    its own ``fixture_date``, inclusive. The percentage of
    ``type == 'New mutual'`` contracts without such a history is written,
    rounded to an integer, to
    ``{overleaf_path}parameters/parameter_relationship.tex``.

    Side effects: adds ``relationship`` (pair id) and ``past_relationship``
    (bool) columns to ``df_contracts``. Each row scans the whole frame, so
    the cost is quadratic in the number of contracts.
    """
    df_contracts['relationship'] = df_contracts.groupby(
        ['rig_name', 'operator']).ngroup()

    lookback = pd.Timedelta(365 * 2, unit='d')
    one_day = pd.Timedelta(1, unit='d')
    pair_ids = df_contracts['relationship']
    all_dates = df_contracts['fixture_date']

    def has_recent_contract(pair_id, fixture_date):
        same_pair = pair_ids == pair_id
        recent = all_dates.between(fixture_date - lookback, fixture_date - one_day)
        return len(df_contracts[same_pair & recent]) >= 1

    flags = [
        has_recent_contract(pair_id, fixture_date)
        for pair_id, fixture_date in zip(df_contracts['relationship'],
                                         df_contracts['fixture_date'])
    ]
    df_contracts['past_relationship'] = np.array(flags)

    is_new_mutual = df_contracts['type'] == 'New mutual'
    share_with_history = df_contracts.loc[is_new_mutual, 'past_relationship'].mean()

    parameter_file = overleaf_path + 'parameters/parameter_relationship.tex'
    with open(parameter_file, 'w') as fout:
        fout.write(f"{int(round((1 - share_with_history) * 100, 0))}")


def get_utilization_table(df_state, overleaf_path):
    """Fill the utilization-by-boom LaTeX table and write it to ``tables/``.

    Mean ``utilization_{low,mid,high}`` is computed separately for
    ``boom == True`` and ``boom == False`` and rounded to 3 decimals. The
    values are substituted into ``./src/tex/table_utilization.tex`` as
    ``{spec}_{boom}`` placeholders (e.g. ``{low_True}``). The result is
    written to ``{overleaf_path}tables/table_utilization.tex``.
    """
    specs = ('low', 'mid', 'high')
    utilization_columns = [f'utilization_{spec}' for spec in specs]

    # Series indexed by (column, boom) after unstacking the per-boom means.
    mean_utilization = (
        df_state.groupby(['boom'])[utilization_columns].mean().unstack()
    )

    tex_input = {
        f'{spec}_{boom}': round(mean_utilization[(f'utilization_{spec}', boom)], 3)
        for boom in (True, False)
        for spec in specs
    }

    # PRODUCE THE TABLE -----------------------------------------------------------------
    with open('./src/tex/table_utilization.tex', 'r') as template_file:
        rendered = template_file.read().format(**tex_input)

    with open(overleaf_path + 'tables/table_utilization.tex', 'w') as table_file:
        table_file.write(rendered)


def get_data_construction(df_wells_with_mri, df_contracts_init, contracts_with_map,
                          df_merged_wells_no_impute, df_merged_contracts_no_impute,
                          df_contracts, overleaf_path):
    """Write the sample-construction counts reported in the paper.

    The sample period is 2000-01-01 to 2009-12-31, inclusive. A contract
    counts when its ``fixture_date`` falls in the period. A well counts when
    its ``spud_date`` or its ``depth_date`` does; missing dates never count.

    Each count is written to
    ``{overleaf_path}parameters/parameter_<name>.tex``. The files are:
    ``n_contracts_initial``, ``n_contracts_with_map``, ``n_wells_with_map``,
    ``n_wells_merged_no_impute``, ``n_contracts_merged_no_impute``,
    ``n_contracts_total`` (``len(df_contracts)``), ``n_contracts_imputed``
    (total minus merged-without-imputation) and two match percentages.

    Args:
        df_wells_with_mri, df_merged_wells_no_impute (pd.DataFrame): wells
            with datetime ``spud_date`` and ``depth_date`` columns.
        df_contracts_init, contracts_with_map, df_merged_contracts_no_impute
            (pd.DataFrame): contracts with a datetime ``fixture_date`` column.
        df_contracts (pd.DataFrame): final contract sample (only its length
            is used).
        overleaf_path (str): output prefix; ``parameters/`` must exist.
    """
    sample_start = pd.to_datetime('2000-01-01')
    sample_end = pd.to_datetime('2009-12-31')

    def in_sample(dates):
        # Inclusive on both ends; NaT compares False.
        return dates.between(sample_start, sample_end)

    def n_contracts_in_sample(df):
        return int(in_sample(df['fixture_date']).sum())

    def n_wells_in_sample(df):
        return int((in_sample(df['spud_date']) | in_sample(df['depth_date'])).sum())

    def write_parameter(name, text):
        with open(overleaf_path + f'parameters/parameter_{name}.tex', 'w') as fout:
            fout.write(text)

    # Contract cleaning
    n_initial_contracts = n_contracts_in_sample(df_contracts_init)
    write_parameter('n_contracts_initial', f"{n_initial_contracts}")

    n_contracts_with_map = n_contracts_in_sample(contracts_with_map)
    write_parameter('n_contracts_with_map', f"{n_contracts_with_map}")

    # Well dataset cleaning
    n_wells = n_wells_in_sample(df_wells_with_mri)
    write_parameter('n_wells_with_map', f"{n_wells}")

    # Merged contracts and wells
    n_wells_merged_no_impute = n_wells_in_sample(df_merged_wells_no_impute)
    write_parameter('n_wells_merged_no_impute', f"{n_wells_merged_no_impute}")

    n_contracts_merged_no_impute = n_contracts_in_sample(df_merged_contracts_no_impute)
    write_parameter('n_contracts_merged_no_impute', f"{n_contracts_merged_no_impute}")

    # N of contracts
    n_contracts = len(df_contracts)
    write_parameter('n_contracts_total', f"{n_contracts}")

    # Imputed contracts are those not matched without imputation.
    write_parameter('n_contracts_imputed', f"{n_contracts - n_contracts_merged_no_impute}")

    # Percent matched; "\\%" is a literal LaTeX-escaped percent sign.
    write_parameter(
        'n_contracts_pct_matched',
        f"{int(round(100 * n_contracts / n_contracts_with_map, 0))}\\%")
    write_parameter(
        'n_wells_pct_matched',
        f"{int(round(100 * n_wells_merged_no_impute / n_wells, 0))}\\%")


def get_out_of_sample(df_state, df_contracts, data, params, weights, output_path):
    """Plot the benchmark simulation's out-of-sample fit.

    ``simulation.run_steps`` is run at ``params``. Over the window of
    January-April 2010 plus 2011 onwards (through 2013 for utilization and
    through 2014 for prices), two panels are drawn:
        (a) mean utilization by rig type, data (``data['state_data']``) vs.
            simulation (``moments_by_state``);
        (b) mid-minus-low and high-minus-mid differences in mean contract
            prices, data (``day_rate``) vs. simulation (``price_predicted``).

    Side effects:
        - ``p_4``, ``a_0``, ``a_1`` and ``denom`` are added to ``params``
          before ``x_names`` is built, so they are passed to the simulation.
        - ``df_contracts['fixture_date']`` is converted to datetime and a
          ``price_predicted`` column is added.
        - Global matplotlib rc parameters are set, including
          ``text.usetex=True``, which requires a LaTeX installation.
        - The figure is saved to ``output_path`` and then closed.
    """
    specs = ['low', 'mid', 'high']
    mri_max = 2.15

    # Pre-compute derived parameters used by the simulation.
    params['p_4'] = 1.0 - params['p_3'] - params['p_2']
    params['a_0'] = np.array([1.0, 1.0, 1.0])
    params['a_1'] = np.array([params['a_1_low'], params['a_1_mid'], params['a_1_high']])
    params['denom'] = (
        scipy.stats.norm.cdf(mri_max, loc=params['mu_0'], scale=params['sigma_0'])
        - scipy.stats.norm.cdf(0, loc=params['mu_0'], scale=params['sigma_0'])
    )

    # DO THE SIMULATION --------------------------------------------------------------
    args_config = {
        'params_fixed': {},
        'weights': weights,
        'mri_max': mri_max,
        'verbose': False,
        'verbose_output': True,
        'value_zero': False,
        'entry_prob_by_tau_by_ym': None,
        'tau_multiplier': 1,
        'x_names': list(params.keys()),
        'well_target': True
    }
    args_by_names = {**args_config, **data}
    benchmark = simulation.run_steps(
        x=list(params.values()),
        args_by_names=args_by_names
    )

    def in_window(dates, last_year):
        """Mask for January-April 2010, or any date from 2011 to ``last_year``."""
        dates = pd.to_datetime(dates)
        year = dates.dt.year
        month = dates.dt.month
        return (year <= last_year) & ((year >= 2011) | ((year == 2010) & (month <= 4)))

    # Utilization: data vs. simulation ------------------------------------------------
    df_moments = pd.DataFrame(benchmark['moments_by_state'])
    df_moments['month'] = df_state['month']

    date_mask = in_window(df_moments['month'], last_year=2013)
    empirical_window = data['state_data'][date_mask]
    simulated_window = df_moments[date_mask]
    empirical_list = [empirical_window[f'utilization_{spec}'].mean() for spec in specs]
    sim_list = [simulated_window[f'utilization_{spec}'].mean() for spec in specs]

    # Prices: data vs. simulation -----------------------------------------------------
    df_contracts['fixture_date'] = pd.to_datetime(df_contracts['fixture_date'])
    for spec in specs:
        df_contracts.loc[df_contracts['spec'] == spec, 'price_predicted'] \
            = benchmark['prices_new_by_spec'][f'price_{spec}']

    contract_window = in_window(df_contracts['fixture_date'], last_year=2014)
    predict_price_by_spec = dict()
    dayrate_by_spec = dict()
    for spec in specs:
        rows = (df_contracts['spec'] == spec) & contract_window
        predict_price_by_spec[spec] = df_contracts.loc[rows, 'price_predicted'].mean()
        dayrate_by_spec[spec] = df_contracts.loc[rows, 'day_rate'].mean()

    diff_by_spec = {
        ('mh', 'price_predicted'): predict_price_by_spec['high'] - predict_price_by_spec['mid'],
        ('lm', 'price_predicted'): predict_price_by_spec['mid'] - predict_price_by_spec['low'],
        ('mh', 'day_rate'): dayrate_by_spec['high'] - dayrate_by_spec['mid'],
        ('lm', 'day_rate'): dayrate_by_spec['mid'] - dayrate_by_spec['low'],
    }

    # DO GRAPH -------------------------------------------------------------------------
    # Styling must be set before any artist exists. Matplotlib text objects
    # read font and usetex rcParams when they are created, so setting them
    # after plotting (as before) only affected text created later.
    plt.rc('font', size=9)  # controls default text sizes
    plt.rc('axes', titlesize=11)  # fontsize of the axes title
    plt.rc('axes', labelsize=10)  # fontsize of the x and y labels
    plt.rc('xtick', labelsize=9)  # fontsize of the tick labels
    plt.rc('ytick', labelsize=9)  # fontsize of the tick labels
    plt.rc('legend', fontsize=9)  # legend fontsize
    plt.rc('figure', titlesize=11)  # figure title size
    plt.rcParams['text.latex.preamble'] = r"\usepackage{lmodern}"
    plt.rcParams['font.family'] = 'serif'
    if 'lmodern' not in plt.rcParams['font.serif']:
        # Guard so repeated calls do not keep prepending the same font.
        plt.rcParams['font.serif'] = ['lmodern'] + plt.rcParams['font.serif']
    rc('text', usetex=True)

    f = plt.figure(figsize=(7, 3.0))

    # Left panel
    ax0 = plt.subplot(1, 2, 1)
    ax0.scatter([1, 2, 3], empirical_list, marker="o", color='gray', label='Data')
    ax0.scatter([1, 2, 3], sim_list, marker='o', color='black', label='Simulation')

    ax0.set_yticks([0, 0.5, 1.0])
    ax0.set_ylim([-0.1, 1.0])
    ax0.set_xlim([0.8, 3.2])
    ax0.spines['top'].set_visible(False)
    ax0.spines['right'].set_visible(False)

    ax0.set_ylabel('Utilization')
    ax0.set_xticks([1, 2, 3])
    ax0.set_xticklabels(['Low', 'Mid', 'High'])
    ax0.set_xlabel('Rig Type')

    # Right panel (shares colours with the left-panel legend)
    ax1 = plt.subplot(1, 2, 2)
    ax1.scatter(
        [1, 2],
        [diff_by_spec[('lm', 'day_rate')], diff_by_spec[('mh', 'day_rate')]],
        color='gray',
    )
    ax1.scatter(
        [1, 2],
        [diff_by_spec[('lm', 'price_predicted')],
         diff_by_spec[('mh', 'price_predicted')]],
        color='black',
    )

    ax1.set_ylim([-0.02, 0.06])
    ax1.set_xlim([0.8, 2.2])
    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)

    ax1.set_ylabel('Difference in Prices')
    ax1.set_xticks([1, 2])
    ax1.set_xticklabels(['Mid - Low', 'High - Mid'])
    ax1.set_xlabel('Rig Type')

    # Titles
    ax0.set_title('(a) Utilization', pad=10)
    ax1.set_title('(b) Difference in Prices', pad=10)

    # Single legend in the lower-right corner (the placement the previous
    # repeated legend() calls ended with).
    ax0.legend(loc=4)
    f.tight_layout(pad=3)

    f.savefig(
        output_path,
        bbox_inches='tight'
    )
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(f)


def get_deepening_number(df_contracts, df_contracts_no_deepening, overleaf_path):
    """Write how many contracts the deepening filter removes.

    The removed count is ``len(df_contracts) - len(df_contracts_no_deepening)``.
    It is written as an integer to
    ``{overleaf_path}parameters/parameter_n_contracts_remove_deepening.tex``.
    The same count as a percentage of ``len(df_contracts)``, rounded to one
    decimal, is written to
    ``{overleaf_path}parameters/parameter_pct_contracts_remove_deepening.tex``.
    """
    n_total = len(df_contracts)
    n_removed = n_total - len(df_contracts_no_deepening)

    def write_parameter(name, text):
        with open(overleaf_path + f'parameters/parameter_{name}.tex', 'w') as fout:
            fout.write(text)

    write_parameter('n_contracts_remove_deepening', f"{int(n_removed)}")
    write_parameter('pct_contracts_remove_deepening',
                    f"{round(100 * n_removed / n_total, 1)}")
